Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
hapi
提交
6d876a00
H
hapi
项目概览
PaddlePaddle
/
hapi
通知
11
Star
2
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
4
列表
看板
标记
里程碑
合并请求
7
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
H
hapi
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
4
Issue
4
列表
看板
标记
里程碑
合并请求
7
合并请求
7
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d876a00
编写于
4月 07, 2020
作者:
Q
qingqing01
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Clear code
上级
44d49573
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
12 addition
and
32 deletion
+12
-32
cyclegan/data.py
cyclegan/data.py
+1
-22
cyclegan/infer.py
cyclegan/infer.py
+2
-1
cyclegan/test.py
cyclegan/test.py
+2
-1
cyclegan/train.py
cyclegan/train.py
+7
-8
未找到文件。
cyclegan/data.py
浏览文件 @
6d876a00
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -119,24 +119,3 @@ class ImagePool(object):
return
temp
else
:
return
image
if
__name__
==
'__main__'
:
place
=
fluid
.
CUDAPlace
(
0
)
#fluid.enable_dygraph(place)
dataset
=
DataA
(
shuffle
=
False
)
a_loader
=
fluid
.
io
.
DataLoader
(
dataset
,
feed_list
=
[
fluid
.
data
(
name
=
'im'
,
shape
=
[
None
,
2
,
2
,
],
dtype
=
'float32'
)
],
places
=
place
,
return_list
=
False
,
batch_size
=
2
)
for
data
in
a_loader
:
print
(
data
)
cyclegan/infer.py
浏览文件 @
6d876a00
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -102,6 +102,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"-s"
,
"--input_style"
,
type
=
str
,
default
=
'A'
,
help
=
"A or B"
)
FLAGS
=
parser
.
parse_args
()
print
(
FLAGS
)
check_gpu
(
str
.
lower
(
FLAGS
.
device
)
==
'gpu'
)
check_version
()
main
()
cyclegan/test.py
浏览文件 @
6d876a00
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -97,6 +97,7 @@ if __name__ == "__main__":
default
=
'checkpoint/199'
,
help
=
"The init model file of directory."
)
FLAGS
=
parser
.
parse_args
()
print
(
FLAGS
)
check_gpu
(
str
.
lower
(
FLAGS
.
device
)
==
'gpu'
)
check_version
()
main
()
cyclegan/train.py
浏览文件 @
6d876a00
# Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 20
20
PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -88,21 +88,15 @@ def main():
loader_A
=
fluid
.
io
.
DataLoader
(
data
.
DataA
(),
feed_list
=
[
x
.
forward
()
for
x
in
[
input_A
]]
if
not
FLAGS
.
dynamic
else
None
,
places
=
place
,
shuffle
=
True
,
return_list
=
True
,
use_buffer_reader
=
True
,
batch_size
=
FLAGS
.
batch_size
)
loader_B
=
fluid
.
io
.
DataLoader
(
data
.
DataB
(),
feed_list
=
[
x
.
forward
()
for
x
in
[
input_B
]]
if
not
FLAGS
.
dynamic
else
None
,
places
=
place
,
shuffle
=
True
,
return_list
=
True
,
use_buffer_reader
=
True
,
batch_size
=
FLAGS
.
batch_size
)
A_pool
=
data
.
ImagePool
()
...
...
@@ -136,7 +130,11 @@ if __name__ == "__main__":
parser
.
add_argument
(
"-d"
,
"--dynamic"
,
action
=
'store_false'
,
help
=
"Enable dygraph mode"
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
'gpu'
,
help
=
"device to use, gpu or cpu"
)
"-p"
,
"--device"
,
type
=
str
,
default
=
'gpu'
,
help
=
"device to use, gpu or cpu"
)
parser
.
add_argument
(
"-e"
,
"--epoch"
,
default
=
200
,
type
=
int
,
help
=
"Epoch number"
)
parser
.
add_argument
(
...
...
@@ -154,6 +152,7 @@ if __name__ == "__main__":
type
=
str
,
help
=
"checkpoint path to resume"
)
FLAGS
=
parser
.
parse_args
()
print
(
FLAGS
)
check_gpu
(
str
.
lower
(
FLAGS
.
device
)
==
'gpu'
)
check_version
()
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录