Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
d8aada07
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d8aada07
编写于
11月 14, 2016
作者:
W
wangyang59
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added cifar data into dema/gan
上级
fb0d80d5
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
82 addition
and
40 deletion
+82
-40
demo/gan/.gitignore
demo/gan/.gitignore
+3
-1
demo/gan/data/download_cifar.sh
demo/gan/data/download_cifar.sh
+18
-0
demo/gan/gan_conf_image.py
demo/gan/gan_conf_image.py
+7
-4
demo/gan/gan_trainer_image.py
demo/gan/gan_trainer_image.py
+54
-35
未找到文件。
demo/gan/.gitignore
浏览文件 @
d8aada07
...
@@ -2,5 +2,7 @@ output/
...
@@ -2,5 +2,7 @@ output/
*.png
*.png
.pydevproject
.pydevproject
.project
.project
train.log
*.log
*.pyc
data/raw_data/
data/raw_data/
data/cifar-10-batches-py/
demo/gan/data/download_cifar.sh
0 → 100755
浏览文件 @
d8aada07
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set
-e
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar
zxf cifar-10-python.tar.gz
rm
cifar-10-python.tar.gz
demo/gan/gan_conf_image.py
浏览文件 @
d8aada07
...
@@ -12,10 +12,9 @@
...
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
from
paddle.trainer_config_helpers
import
*
from
paddle.trainer_config_helpers
import
*
from
paddle.trainer_config_helpers.activations
import
LinearActivation
from
numpy.distutils.system_info
import
tmp
mode
=
get_config_arg
(
"mode"
,
str
,
"generator"
)
mode
=
get_config_arg
(
"mode"
,
str
,
"generator"
)
dataSource
=
get_config_arg
(
"data"
,
str
,
"mnist"
)
assert
mode
in
set
([
"generator"
,
assert
mode
in
set
([
"generator"
,
"discriminator"
,
"discriminator"
,
"generator_training"
,
"generator_training"
,
...
@@ -30,8 +29,12 @@ print('mode=%s' % mode)
...
@@ -30,8 +29,12 @@ print('mode=%s' % mode)
noise_dim
=
100
noise_dim
=
100
gf_dim
=
64
gf_dim
=
64
df_dim
=
64
df_dim
=
64
sample_dim
=
28
# image dim
if
dataSource
==
"mnist"
:
c_dim
=
1
# image color
sample_dim
=
28
# image dim
c_dim
=
1
# image color
else
:
sample_dim
=
32
c_dim
=
3
s2
,
s4
=
int
(
sample_dim
/
2
),
int
(
sample_dim
/
4
),
s2
,
s4
=
int
(
sample_dim
/
2
),
int
(
sample_dim
/
4
),
s8
,
s16
=
int
(
sample_dim
/
8
),
int
(
sample_dim
/
16
)
s8
,
s16
=
int
(
sample_dim
/
8
),
int
(
sample_dim
/
16
)
...
...
demo/gan/gan_trainer_image.py
浏览文件 @
d8aada07
...
@@ -16,31 +16,13 @@ import argparse
...
@@ -16,31 +16,13 @@ import argparse
import
itertools
import
itertools
import
random
import
random
import
numpy
import
numpy
import
cPickle
import
sys
,
os
,
gc
import
sys
,
os
,
gc
from
PIL
import
Image
from
PIL
import
Image
from
paddle.trainer.config_parser
import
parse_config
from
paddle.trainer.config_parser
import
parse_config
from
paddle.trainer.config_parser
import
logger
from
paddle.trainer.config_parser
import
logger
import
py_paddle.swig_paddle
as
api
import
py_paddle.swig_paddle
as
api
from
py_paddle
import
DataProviderConverter
import
matplotlib.pyplot
as
plt
def
plot2DScatter
(
data
,
outputfile
):
# Generate some test data
x
=
data
[:,
0
]
y
=
data
[:,
1
]
print
"The mean vector is %s"
%
numpy
.
mean
(
data
,
0
)
print
"The std vector is %s"
%
numpy
.
std
(
data
,
0
)
heatmap
,
xedges
,
yedges
=
numpy
.
histogram2d
(
x
,
y
,
bins
=
50
)
extent
=
[
xedges
[
0
],
xedges
[
-
1
],
yedges
[
0
],
yedges
[
-
1
]]
plt
.
clf
()
plt
.
scatter
(
x
,
y
)
# plt.show()
plt
.
savefig
(
outputfile
,
bbox_inches
=
'tight'
)
def
CHECK_EQ
(
a
,
b
):
def
CHECK_EQ
(
a
,
b
):
assert
a
==
b
,
"a=%s, b=%s"
%
(
a
,
b
)
assert
a
==
b
,
"a=%s, b=%s"
%
(
a
,
b
)
...
@@ -94,18 +76,39 @@ def load_mnist_data(imageFile):
...
@@ -94,18 +76,39 @@ def load_mnist_data(imageFile):
f
.
close
()
f
.
close
()
return
data
return
data
def
load_cifar_data
(
cifar_path
):
batch_size
=
10000
data
=
numpy
.
zeros
((
5
*
batch_size
,
32
*
32
*
3
),
dtype
=
"float32"
)
for
i
in
range
(
1
,
6
):
file
=
cifar_path
+
"/data_batch_"
+
str
(
i
)
fo
=
open
(
file
,
'rb'
)
dict
=
cPickle
.
load
(
fo
)
fo
.
close
()
data
[(
i
-
1
)
*
batch_size
:(
i
*
batch_size
),
:]
=
dict
[
"data"
]
data
=
data
/
255.0
*
2.0
-
1.0
return
data
def
merge
(
images
,
size
):
def
merge
(
images
,
size
):
h
,
w
=
28
,
28
if
images
.
shape
[
1
]
==
28
*
28
:
img
=
numpy
.
zeros
((
h
*
size
[
0
],
w
*
size
[
1
]))
h
,
w
,
c
=
28
,
28
,
1
else
:
h
,
w
,
c
=
32
,
32
,
3
img
=
numpy
.
zeros
((
h
*
size
[
0
],
w
*
size
[
1
],
c
))
for
idx
in
xrange
(
size
[
0
]
*
size
[
1
]):
for
idx
in
xrange
(
size
[
0
]
*
size
[
1
]):
i
=
idx
%
size
[
1
]
i
=
idx
%
size
[
1
]
j
=
idx
//
size
[
1
]
j
=
idx
//
size
[
1
]
img
[
j
*
h
:
j
*
h
+
h
,
i
*
w
:
i
*
w
+
w
]
=
(
images
[
idx
,
:].
reshape
((
h
,
w
))
+
1.0
)
/
2.0
*
255.0
#img[j*h:j*h+h, i*w:i*w+w, :] = (images[idx, :].reshape((h, w, c), order="F") + 1.0) / 2.0 * 255.0
return
img
img
[
j
*
h
:
j
*
h
+
h
,
i
*
w
:
i
*
w
+
w
,
:]
=
\
((
images
[
idx
,
:].
reshape
((
h
,
w
,
c
),
order
=
"F"
).
transpose
(
1
,
0
,
2
)
+
1.0
)
/
2.0
*
255.0
)
return
img
.
astype
(
'uint8'
)
def
saveImages
(
images
,
path
):
def
saveImages
(
images
,
path
):
merged_img
=
merge
(
images
,
[
8
,
8
])
merged_img
=
merge
(
images
,
[
8
,
8
])
im
=
Image
.
fromarray
(
merged_img
).
convert
(
'RGB'
)
if
merged_img
.
shape
[
2
]
==
1
:
im
=
Image
.
fromarray
(
numpy
.
squeeze
(
merged_img
)).
convert
(
'RGB'
)
else
:
im
=
Image
.
fromarray
(
merged_img
,
mode
=
"RGB"
)
im
.
save
(
path
)
im
.
save
(
path
)
def
get_real_samples
(
batch_size
,
data_np
):
def
get_real_samples
(
batch_size
,
data_np
):
...
@@ -115,9 +118,9 @@ def get_real_samples(batch_size, data_np):
...
@@ -115,9 +118,9 @@ def get_real_samples(batch_size, data_np):
def
get_noise
(
batch_size
,
noise_dim
):
def
get_noise
(
batch_size
,
noise_dim
):
return
numpy
.
random
.
normal
(
size
=
(
batch_size
,
noise_dim
)).
astype
(
'float32'
)
return
numpy
.
random
.
normal
(
size
=
(
batch_size
,
noise_dim
)).
astype
(
'float32'
)
def
get_sample_noise
(
batch_size
):
def
get_sample_noise
(
batch_size
,
sample_dim
):
return
numpy
.
random
.
normal
(
size
=
(
batch_size
,
28
*
28
),
return
numpy
.
random
.
normal
(
size
=
(
batch_size
,
sample_dim
),
scale
=
0.1
).
astype
(
'float32'
)
scale
=
0.
0
1
).
astype
(
'float32'
)
def
get_fake_samples
(
generator_machine
,
batch_size
,
noise
):
def
get_fake_samples
(
generator_machine
,
batch_size
,
noise
):
gen_inputs
=
api
.
Arguments
.
createArguments
(
1
)
gen_inputs
=
api
.
Arguments
.
createArguments
(
1
)
...
@@ -177,15 +180,31 @@ def get_layer_size(model_conf, layer_name):
...
@@ -177,15 +180,31 @@ def get_layer_size(model_conf, layer_name):
def
main
():
def
main
():
api
.
initPaddle
(
'--use_gpu=1'
,
'--dot_period=10'
,
'--log_period=100'
)
parser
=
argparse
.
ArgumentParser
()
gen_conf
=
parse_config
(
"gan_conf_image.py"
,
"mode=generator_training"
)
parser
.
add_argument
(
"-d"
,
"--dataSource"
,
help
=
"mnist or cifar"
)
dis_conf
=
parse_config
(
"gan_conf_image.py"
,
"mode=discriminator_training"
)
parser
.
add_argument
(
"--useGpu"
,
default
=
"1"
,
generator_conf
=
parse_config
(
"gan_conf_image.py"
,
"mode=generator"
)
help
=
"1 means use gpu for training"
)
args
=
parser
.
parse_args
()
dataSource
=
args
.
dataSource
useGpu
=
args
.
useGpu
assert
dataSource
in
[
"mnist"
,
"cifar"
]
assert
useGpu
in
[
"0"
,
"1"
]
api
.
initPaddle
(
'--use_gpu='
+
useGpu
,
'--dot_period=10'
,
'--log_period=100'
)
gen_conf
=
parse_config
(
"gan_conf_image.py"
,
"mode=generator_training,data="
+
dataSource
)
dis_conf
=
parse_config
(
"gan_conf_image.py"
,
"mode=discriminator_training,data="
+
dataSource
)
generator_conf
=
parse_config
(
"gan_conf_image.py"
,
"mode=generator,data="
+
dataSource
)
batch_size
=
dis_conf
.
opt_config
.
batch_size
batch_size
=
dis_conf
.
opt_config
.
batch_size
noise_dim
=
get_layer_size
(
gen_conf
.
model_config
,
"noise"
)
noise_dim
=
get_layer_size
(
gen_conf
.
model_config
,
"noise"
)
sample_dim
=
get_layer_size
(
dis_conf
.
model_config
,
"sample"
)
sample_dim
=
get_layer_size
(
dis_conf
.
model_config
,
"sample"
)
if
dataSource
==
"mnist"
:
data_np
=
load_mnist_data
(
"./data/raw_data/train-images-idx3-ubyte"
)
data_np
=
load_mnist_data
(
"./data/raw_data/train-images-idx3-ubyte"
)
else
:
data_np
=
load_cifar_data
(
"./data/cifar-10-batches-py/"
)
if
not
os
.
path
.
exists
(
"./%s_samples/"
%
dataSource
):
os
.
makedirs
(
"./%s_samples/"
%
dataSource
)
# this create a gradient machine for discriminator
# this create a gradient machine for discriminator
dis_training_machine
=
api
.
GradientMachine
.
createFromConfigProto
(
dis_training_machine
=
api
.
GradientMachine
.
createFromConfigProto
(
...
@@ -224,12 +243,12 @@ def main():
...
@@ -224,12 +243,12 @@ def main():
# generator_machine, batch_size, noise_dim, sample_dim)
# generator_machine, batch_size, noise_dim, sample_dim)
# dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
# dis_loss = get_training_loss(dis_training_machine, data_batch_dis)
noise
=
get_noise
(
batch_size
,
noise_dim
)
noise
=
get_noise
(
batch_size
,
noise_dim
)
sample_noise
=
get_sample_noise
(
batch_size
)
sample_noise
=
get_sample_noise
(
batch_size
,
sample_dim
)
data_batch_dis_pos
=
prepare_discriminator_data_batch_pos
(
data_batch_dis_pos
=
prepare_discriminator_data_batch_pos
(
batch_size
,
data_np
,
sample_noise
)
batch_size
,
data_np
,
sample_noise
)
dis_loss_pos
=
get_training_loss
(
dis_training_machine
,
data_batch_dis_pos
)
dis_loss_pos
=
get_training_loss
(
dis_training_machine
,
data_batch_dis_pos
)
sample_noise
=
get_sample_noise
(
batch_size
)
sample_noise
=
get_sample_noise
(
batch_size
,
sample_dim
)
data_batch_dis_neg
=
prepare_discriminator_data_batch_neg
(
data_batch_dis_neg
=
prepare_discriminator_data_batch_neg
(
generator_machine
,
batch_size
,
noise
,
sample_noise
)
generator_machine
,
batch_size
,
noise
,
sample_noise
)
dis_loss_neg
=
get_training_loss
(
dis_training_machine
,
data_batch_dis_neg
)
dis_loss_neg
=
get_training_loss
(
dis_training_machine
,
data_batch_dis_neg
)
...
@@ -271,7 +290,7 @@ def main():
...
@@ -271,7 +290,7 @@ def main():
fake_samples
=
get_fake_samples
(
generator_machine
,
batch_size
,
noise
)
fake_samples
=
get_fake_samples
(
generator_machine
,
batch_size
,
noise
)
saveImages
(
fake_samples
,
"
train_pass%s.png"
%
train_pass
)
saveImages
(
fake_samples
,
"
./%s_samples/train_pass%s.png"
%
(
dataSource
,
train_pass
)
)
dis_trainer
.
finishTrain
()
dis_trainer
.
finishTrain
()
gen_trainer
.
finishTrain
()
gen_trainer
.
finishTrain
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录