Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
e4c55582
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e4c55582
编写于
9月 10, 2020
作者:
Q
qingqing01
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' of
https://github.com/PaddlePaddle/hapi
into clean_code
上级
d2195325
224515a6
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
56 addition
and
47 deletion
+56
-47
cyclegan/cyclegan.py
cyclegan/cyclegan.py
+10
-14
cyclegan/infer.py
cyclegan/infer.py
+13
-7
cyclegan/test.py
cyclegan/test.py
+12
-7
cyclegan/train.py
cyclegan/train.py
+21
-19
未找到文件。
cyclegan/cyclegan.py
浏览文件 @
e4c55582
...
...
@@ -18,9 +18,8 @@ from __future__ import print_function
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
from
paddle.incubate.hapi.model
import
Model
from
paddle.incubate.hapi.loss
import
Loss
from
layers
import
ConvBN
,
DeConvBN
...
...
@@ -133,7 +132,7 @@ class NLayerDiscriminator(fluid.dygraph.Layer):
return
y
class
Generator
(
Model
):
class
Generator
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
input_channel
=
3
):
super
(
Generator
,
self
).
__init__
()
self
.
g
=
ResnetGenerator
(
input_channel
)
...
...
@@ -143,7 +142,7 @@ class Generator(Model):
return
fake
class
GeneratorCombine
(
Model
):
class
GeneratorCombine
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
g_AB
=
None
,
g_BA
=
None
,
d_A
=
None
,
d_B
=
None
,
is_train
=
True
):
super
(
GeneratorCombine
,
self
).
__init__
()
...
...
@@ -177,16 +176,15 @@ class GeneratorCombine(Model):
return
input_A
,
input_B
,
fake_A
,
fake_B
,
cyc_A
,
cyc_B
,
idt_A
,
idt_B
,
valid_A
,
valid_B
class
GLoss
(
Loss
):
class
GLoss
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
lambda_A
=
10.
,
lambda_B
=
10.
,
lambda_identity
=
0.5
):
super
(
GLoss
,
self
).
__init__
()
self
.
lambda_A
=
lambda_A
self
.
lambda_B
=
lambda_B
self
.
lambda_identity
=
lambda_identity
def
forward
(
self
,
outputs
,
labels
=
None
):
input_A
,
input_B
,
fake_A
,
fake_B
,
cyc_A
,
cyc_B
,
idt_A
,
idt_B
,
valid_A
,
valid_B
=
outputs
def
forward
(
self
,
input_A
,
input_B
,
fake_A
,
fake_B
,
cyc_A
,
cyc_B
,
idt_A
,
idt_B
,
valid_A
,
valid_B
):
def
mse
(
a
,
b
):
return
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
square
(
a
-
b
))
...
...
@@ -211,7 +209,7 @@ class GLoss(Loss):
return
loss
class
Discriminator
(
Model
):
class
Discriminator
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
input_channel
=
3
):
super
(
Discriminator
,
self
).
__init__
()
self
.
d
=
NLayerDiscriminator
(
input_channel
)
...
...
@@ -222,13 +220,11 @@ class Discriminator(Model):
return
pred_real
,
pred_fake
class
DLoss
(
Loss
):
class
DLoss
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
):
super
(
DLoss
,
self
).
__init__
()
def
forward
(
self
,
inputs
,
labels
=
None
):
pred_real
,
pred_fake
=
inputs
loss
=
fluid
.
layers
.
square
(
pred_fake
)
+
fluid
.
layers
.
square
(
pred_real
-
1.
)
def
forward
(
self
,
real
,
fake
):
loss
=
fluid
.
layers
.
square
(
fake
)
+
fluid
.
layers
.
square
(
real
-
1.
)
loss
=
fluid
.
layers
.
reduce_mean
(
loss
/
2.0
)
return
loss
cyclegan/infer.py
浏览文件 @
e4c55582
...
...
@@ -24,26 +24,32 @@ import argparse
from
PIL
import
Image
from
scipy.misc
import
imsave
import
paddle
import
paddle.fluid
as
fluid
from
paddle.
incubate.hapi.model
import
Model
,
Input
,
set_device
from
paddle.
static
import
InputSpec
as
Input
from
check
import
check_gpu
,
check_version
from
cyclegan
import
Generator
,
GeneratorCombine
def
main
():
place
=
set_device
(
FLAGS
.
device
)
place
=
paddle
.
set_device
(
FLAGS
.
device
)
fluid
.
enable_dygraph
(
place
)
if
FLAGS
.
dynamic
else
None
im_shape
=
[
-
1
,
3
,
256
,
256
]
input_A
=
Input
(
im_shape
,
'float32'
,
'input_A'
)
input_B
=
Input
(
im_shape
,
'float32'
,
'input_B'
)
# Generators
g_AB
=
Generator
()
g_BA
=
Generator
()
g
=
GeneratorCombine
(
g_AB
,
g_BA
,
is_train
=
False
)
im_shape
=
[
-
1
,
3
,
256
,
256
]
input_A
=
Input
(
im_shape
,
'float32'
,
'input_A'
)
input_B
=
Input
(
im_shape
,
'float32'
,
'input_B'
)
g
.
prepare
(
inputs
=
[
input_A
,
input_B
],
device
=
FLAGS
.
device
)
g
=
paddle
.
Model
(
GeneratorCombine
(
g_AB
,
g_BA
,
is_train
=
False
),
inputs
=
[
input_A
,
input_B
])
g
.
prepare
()
g
.
load
(
FLAGS
.
init_model
,
skip_mismatch
=
True
,
reset_optimizer
=
True
)
out_path
=
FLAGS
.
output
+
"/single"
...
...
cyclegan/test.py
浏览文件 @
e4c55582
...
...
@@ -21,8 +21,9 @@ import argparse
import
numpy
as
np
from
scipy.misc
import
imsave
import
paddle
import
paddle.fluid
as
fluid
from
paddle.
incubate.hapi.model
import
Model
,
Input
,
set_device
from
paddle.
static
import
InputSpec
as
Input
from
check
import
check_gpu
,
check_version
from
cyclegan
import
Generator
,
GeneratorCombine
...
...
@@ -30,18 +31,22 @@ import data as data
def
main
():
place
=
set_device
(
FLAGS
.
device
)
place
=
paddle
.
set_device
(
FLAGS
.
device
)
fluid
.
enable_dygraph
(
place
)
if
FLAGS
.
dynamic
else
None
im_shape
=
[
-
1
,
3
,
256
,
256
]
input_A
=
Input
(
im_shape
,
'float32'
,
'input_A'
)
input_B
=
Input
(
im_shape
,
'float32'
,
'input_B'
)
# Generators
g_AB
=
Generator
()
g_BA
=
Generator
()
g
=
GeneratorCombine
(
g_AB
,
g_BA
,
is_train
=
False
)
g
=
paddle
.
Model
(
GeneratorCombine
(
g_AB
,
g_BA
,
is_train
=
False
),
inputs
=
[
input_A
,
input_B
])
im_shape
=
[
-
1
,
3
,
256
,
256
]
input_A
=
Input
(
im_shape
,
'float32'
,
'input_A'
)
input_B
=
Input
(
im_shape
,
'float32'
,
'input_B'
)
g
.
prepare
(
inputs
=
[
input_A
,
input_B
],
device
=
FLAGS
.
device
)
g
.
prepare
()
g
.
load
(
FLAGS
.
init_model
,
skip_mismatch
=
True
,
reset_optimizer
=
True
)
if
not
os
.
path
.
exists
(
FLAGS
.
output
):
...
...
cyclegan/train.py
浏览文件 @
e4c55582
...
...
@@ -24,7 +24,7 @@ import time
import
paddle
import
paddle.fluid
as
fluid
from
paddle.
incubate.hapi.model
import
Model
,
Input
,
set_device
from
paddle.
static
import
InputSpec
as
Input
from
check
import
check_gpu
,
check_version
from
cyclegan
import
Generator
,
Discriminator
,
GeneratorCombine
,
GLoss
,
DLoss
...
...
@@ -48,18 +48,29 @@ def opt(parameters):
def
main
():
place
=
set_device
(
FLAGS
.
device
)
place
=
paddle
.
set_device
(
FLAGS
.
device
)
fluid
.
enable_dygraph
(
place
)
if
FLAGS
.
dynamic
else
None
im_shape
=
[
None
,
3
,
256
,
256
]
input_A
=
Input
(
im_shape
,
'float32'
,
'input_A'
)
input_B
=
Input
(
im_shape
,
'float32'
,
'input_B'
)
fake_A
=
Input
(
im_shape
,
'float32'
,
'fake_A'
)
fake_B
=
Input
(
im_shape
,
'float32'
,
'fake_B'
)
# Generators
g_AB
=
Generator
()
g_BA
=
Generator
()
# Discriminators
d_A
=
Discriminator
()
d_B
=
Discriminator
()
g
=
GeneratorCombine
(
g_AB
,
g_BA
,
d_A
,
d_B
)
g
=
paddle
.
Model
(
GeneratorCombine
(
g_AB
,
g_BA
,
d_A
,
d_B
),
inputs
=
[
input_A
,
input_B
])
g_AB
=
paddle
.
Model
(
g_AB
,
[
input_A
])
g_BA
=
paddle
.
Model
(
g_BA
,
[
input_B
])
# Discriminators
d_A
=
paddle
.
Model
(
d_A
,
[
input_B
,
fake_B
])
d_B
=
paddle
.
Model
(
d_B
,
[
input_A
,
fake_A
])
da_params
=
d_A
.
parameters
()
db_params
=
d_B
.
parameters
()
...
...
@@ -69,21 +80,12 @@ def main():
db_optimizer
=
opt
(
db_params
)
g_optimizer
=
opt
(
g_params
)
im_shape
=
[
None
,
3
,
256
,
256
]
input_A
=
Input
(
im_shape
,
'float32'
,
'input_A'
)
input_B
=
Input
(
im_shape
,
'float32'
,
'input_B'
)
fake_A
=
Input
(
im_shape
,
'float32'
,
'fake_A'
)
fake_B
=
Input
(
im_shape
,
'float32'
,
'fake_B'
)
g_AB
.
prepare
(
inputs
=
[
input_A
],
device
=
FLAGS
.
device
)
g_BA
.
prepare
(
inputs
=
[
input_B
],
device
=
FLAGS
.
device
)
g_AB
.
prepare
()
g_BA
.
prepare
()
g
.
prepare
(
g_optimizer
,
GLoss
(),
inputs
=
[
input_A
,
input_B
],
device
=
FLAGS
.
device
)
d_A
.
prepare
(
da_optimizer
,
DLoss
(),
inputs
=
[
input_B
,
fake_B
],
device
=
FLAGS
.
device
)
d_B
.
prepare
(
db_optimizer
,
DLoss
(),
inputs
=
[
input_A
,
fake_A
],
device
=
FLAGS
.
device
)
g
.
prepare
(
g_optimizer
,
GLoss
())
d_A
.
prepare
(
da_optimizer
,
DLoss
())
d_B
.
prepare
(
db_optimizer
,
DLoss
())
if
FLAGS
.
resume
:
g
.
load
(
FLAGS
.
resume
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录