Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5f6c4af3
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
5f6c4af3
编写于
12月 21, 2016
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Try to read data in mnist
上级
ad93b8f9
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
62 addition
and
25 deletion
+62
-25
demo/mnist/api_train.py
demo/mnist/api_train.py
+29
-0
demo/mnist/mnist_provider.py
demo/mnist/mnist_provider.py
+3
-25
demo/mnist/mnist_util.py
demo/mnist/mnist_util.py
+30
-0
未找到文件。
demo/mnist/api_train.py
浏览文件 @
5f6c4af3
import
py_paddle.swig_paddle
as
api
import
py_paddle.swig_paddle
as
api
from
py_paddle
import
DataProviderConverter
import
paddle.trainer.PyDataProvider2
as
dp
import
paddle.trainer.config_parser
import
paddle.trainer.config_parser
import
numpy
as
np
import
numpy
as
np
from
mnist_util
import
read_from_mnist
def
init_parameter
(
network
):
def
init_parameter
(
network
):
...
@@ -13,6 +16,22 @@ def init_parameter(network):
...
@@ -13,6 +16,22 @@ def init_parameter(network):
array
[
i
]
=
np
.
random
.
uniform
(
-
1.0
,
1.0
)
array
[
i
]
=
np
.
random
.
uniform
(
-
1.0
,
1.0
)
def
generator_to_batch
(
generator
,
batch_size
):
ret_val
=
list
()
for
each_item
in
generator
:
ret_val
.
append
(
each_item
)
if
len
(
ret_val
)
==
batch_size
:
yield
ret_val
ret_val
=
list
()
if
len
(
ret_val
)
!=
0
:
yield
ret_val
def
input_order_converter
(
generator
):
for
each_item
in
generator
:
yield
each_item
[
'pixel'
],
each_item
[
'label'
]
def
main
():
def
main
():
api
.
initPaddle
(
"-use_gpu=false"
,
"-trainer_count=4"
)
# use 4 cpu cores
api
.
initPaddle
(
"-use_gpu=false"
,
"-trainer_count=4"
)
# use 4 cpu cores
config
=
paddle
.
trainer
.
config_parser
.
parse_config
(
config
=
paddle
.
trainer
.
config_parser
.
parse_config
(
...
@@ -30,10 +49,20 @@ def main():
...
@@ -30,10 +49,20 @@ def main():
updater
=
api
.
ParameterUpdater
.
createLocalUpdater
(
opt_config
)
updater
=
api
.
ParameterUpdater
.
createLocalUpdater
(
opt_config
)
assert
isinstance
(
updater
,
api
.
ParameterUpdater
)
assert
isinstance
(
updater
,
api
.
ParameterUpdater
)
updater
.
init
(
m
)
updater
.
init
(
m
)
converter
=
DataProviderConverter
(
input_types
=
[
dp
.
dense_vector
(
784
),
dp
.
integer_value
(
10
)])
train_file
=
'./data/raw_data/train'
m
.
start
()
m
.
start
()
for
_
in
xrange
(
100
):
for
_
in
xrange
(
100
):
updater
.
startPass
()
updater
.
startPass
()
train_data_generator
=
input_order_converter
(
read_from_mnist
(
train_file
))
for
data_batch
in
generator_to_batch
(
train_data_generator
,
128
):
inArgs
=
converter
(
data_batch
)
updater
.
finishPass
()
updater
.
finishPass
()
...
...
demo/mnist/mnist_provider.py
浏览文件 @
5f6c4af3
from
paddle.trainer.PyDataProvider2
import
*
from
paddle.trainer.PyDataProvider2
import
*
import
numpy
from
mnist_util
import
read_from_mnist
# Define a py data provider
# Define a py data provider
...
@@ -8,27 +8,5 @@ import numpy
...
@@ -8,27 +8,5 @@ import numpy
'label'
:
integer_value
(
10
)},
'label'
:
integer_value
(
10
)},
cache
=
CacheType
.
CACHE_PASS_IN_MEM
)
cache
=
CacheType
.
CACHE_PASS_IN_MEM
)
def
process
(
settings
,
filename
):
# settings is not used currently.
def
process
(
settings
,
filename
):
# settings is not used currently.
imgf
=
filename
+
"-images-idx3-ubyte"
for
each
in
read_from_mnist
(
filename
):
labelf
=
filename
+
"-labels-idx1-ubyte"
yield
each
f
=
open
(
imgf
,
"rb"
)
l
=
open
(
labelf
,
"rb"
)
f
.
read
(
16
)
l
.
read
(
8
)
# Define number of samples for train/test
if
"train"
in
filename
:
n
=
60000
else
:
n
=
10000
images
=
numpy
.
fromfile
(
f
,
'ubyte'
,
count
=
n
*
28
*
28
).
reshape
((
n
,
28
*
28
)).
astype
(
'float32'
)
images
=
images
/
255.0
*
2.0
-
1.0
labels
=
numpy
.
fromfile
(
l
,
'ubyte'
,
count
=
n
).
astype
(
"int"
)
for
i
in
xrange
(
n
):
yield
{
"pixel"
:
images
[
i
,
:],
'label'
:
labels
[
i
]}
f
.
close
()
l
.
close
()
demo/mnist/mnist_util.py
0 → 100644
浏览文件 @
5f6c4af3
import
numpy
__all__
=
[
'read_from_mnist'
]
def
read_from_mnist
(
filename
):
imgf
=
filename
+
"-images-idx3-ubyte"
labelf
=
filename
+
"-labels-idx1-ubyte"
f
=
open
(
imgf
,
"rb"
)
l
=
open
(
labelf
,
"rb"
)
f
.
read
(
16
)
l
.
read
(
8
)
# Define number of samples for train/test
if
"train"
in
filename
:
n
=
60000
else
:
n
=
10000
images
=
numpy
.
fromfile
(
f
,
'ubyte'
,
count
=
n
*
28
*
28
).
reshape
((
n
,
28
*
28
)).
astype
(
'float32'
)
images
=
images
/
255.0
*
2.0
-
1.0
labels
=
numpy
.
fromfile
(
l
,
'ubyte'
,
count
=
n
).
astype
(
"int"
)
for
i
in
xrange
(
n
):
yield
{
"pixel"
:
images
[
i
,
:],
'label'
:
labels
[
i
]}
f
.
close
()
l
.
close
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录