Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
409a5774
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
409a5774
编写于
12月 21, 2016
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Complete a very simple mnist demo.
上级
eaba2e2e
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
99 addition
and
9 deletion
+99
-9
demo/mnist/api_train.py
demo/mnist/api_train.py
+99
-9
未找到文件。
demo/mnist/api_train.py
浏览文件 @
409a5774
"""
A very basic example for how to use current Raw SWIG API to train mnist network.
Current implementation uses Raw SWIG, which means the API call is directly
\
passed to C++ side of Paddle.
The user api could be simpler and carefully designed.
"""
import
py_paddle.swig_paddle
as
api
import
py_paddle.swig_paddle
as
api
from
py_paddle
import
DataProviderConverter
from
py_paddle
import
DataProviderConverter
import
paddle.trainer.PyDataProvider2
as
dp
import
paddle.trainer.PyDataProvider2
as
dp
import
paddle.trainer.config_parser
import
paddle.trainer.config_parser
import
numpy
as
np
import
numpy
as
np
import
random
from
mnist_util
import
read_from_mnist
from
mnist_util
import
read_from_mnist
...
@@ -27,6 +36,18 @@ def generator_to_batch(generator, batch_size):
...
@@ -27,6 +36,18 @@ def generator_to_batch(generator, batch_size):
yield
ret_val
yield
ret_val
class
BatchPool
(
object
):
def
__init__
(
self
,
generator
,
batch_size
):
self
.
data
=
list
(
generator
)
self
.
batch_size
=
batch_size
def
__call__
(
self
):
random
.
shuffle
(
self
.
data
)
for
offset
in
xrange
(
0
,
len
(
self
.
data
),
self
.
batch_size
):
limit
=
min
(
offset
+
self
.
batch_size
,
len
(
self
.
data
))
yield
self
.
data
[
offset
:
limit
]
def
input_order_converter
(
generator
):
def
input_order_converter
(
generator
):
for
each_item
in
generator
:
for
each_item
in
generator
:
yield
each_item
[
'pixel'
],
each_item
[
'label'
]
yield
each_item
[
'pixel'
],
each_item
[
'label'
]
...
@@ -37,46 +58,115 @@ def main():
...
@@ -37,46 +58,115 @@ def main():
config
=
paddle
.
trainer
.
config_parser
.
parse_config
(
config
=
paddle
.
trainer
.
config_parser
.
parse_config
(
'simple_mnist_network.py'
,
''
)
'simple_mnist_network.py'
,
''
)
# get enable_types for each optimizer.
# enable_types = [value, gradient, momentum, etc]
# For each optimizer(SGD, Adam), GradientMachine should enable different
# buffers.
opt_config
=
api
.
OptimizationConfig
.
createFromProto
(
config
.
opt_config
)
opt_config
=
api
.
OptimizationConfig
.
createFromProto
(
config
.
opt_config
)
_temp_optimizer_
=
api
.
ParameterOptimizer
.
create
(
opt_config
)
_temp_optimizer_
=
api
.
ParameterOptimizer
.
create
(
opt_config
)
enable_types
=
_temp_optimizer_
.
getParameterTypes
()
enable_types
=
_temp_optimizer_
.
getParameterTypes
()
# Create Simple Gradient Machine.
m
=
api
.
GradientMachine
.
createFromConfigProto
(
m
=
api
.
GradientMachine
.
createFromConfigProto
(
config
.
model_config
,
api
.
CREATE_MODE_NORMAL
,
enable_types
)
config
.
model_config
,
api
.
CREATE_MODE_NORMAL
,
enable_types
)
# This type check is not useful. Only enable type hint in IDE.
# Such as PyCharm
assert
isinstance
(
m
,
api
.
GradientMachine
)
assert
isinstance
(
m
,
api
.
GradientMachine
)
# Initialize Parameter by numpy.
init_parameter
(
network
=
m
)
init_parameter
(
network
=
m
)
# Create Local Updater. Local means not run in cluster.
# For a cluster training, here we can change to createRemoteUpdater
# in future.
updater
=
api
.
ParameterUpdater
.
createLocalUpdater
(
opt_config
)
updater
=
api
.
ParameterUpdater
.
createLocalUpdater
(
opt_config
)
assert
isinstance
(
updater
,
api
.
ParameterUpdater
)
assert
isinstance
(
updater
,
api
.
ParameterUpdater
)
# Initialize ParameterUpdater.
updater
.
init
(
m
)
updater
.
init
(
m
)
# DataProvider Converter is a utility convert Python Object to Paddle C++
# Input. The input format is as same as Paddle's DataProvider.
converter
=
DataProviderConverter
(
converter
=
DataProviderConverter
(
input_types
=
[
dp
.
dense_vector
(
784
),
dp
.
integer_value
(
10
)])
input_types
=
[
dp
.
dense_vector
(
784
),
dp
.
integer_value
(
10
)])
train_file
=
'./data/raw_data/train'
train_file
=
'./data/raw_data/train'
test_file
=
'./data/raw_data/t10k'
# start gradient machine.
# the gradient machine must be started before invoke forward/backward.
# not just for training, but also for inference.
m
.
start
()
m
.
start
()
for
_
in
xrange
(
100
):
# evaluator can print error rate, etc. It is a C++ class.
batch_evaluator
=
m
.
makeEvaluator
()
test_evaluator
=
m
.
makeEvaluator
()
# Get Train Data.
# TrainData will stored in a data pool. Currently implementation is not care
# about memory, speed. Just a very naive implementation.
train_data_generator
=
input_order_converter
(
read_from_mnist
(
train_file
))
train_data
=
BatchPool
(
train_data_generator
,
128
)
# outArgs is Neural Network forward result. Here is not useful, just passed
# to gradient_machine.forward
outArgs
=
api
.
Arguments
.
createArguments
(
0
)
for
pass_id
in
xrange
(
2
):
# we train 2 passes.
updater
.
startPass
()
updater
.
startPass
()
outArgs
=
api
.
Arguments
.
createArguments
(
0
)
train_data_generator
=
input_order_converter
(
read_from_mnist
(
train_file
))
for
batch_id
,
data_batch
in
enumerate
(
generator_to_batch
(
train_data_generator
,
2048
)):
trainRole
=
updater
.
startBatch
(
len
(
data_batch
))
for
batch_id
,
data_batch
in
enumerate
(
train_data
()):
# data_batch is input images.
# here, for online learning, we could get data_batch from network.
# Start update one batch.
pass_type
=
updater
.
startBatch
(
len
(
data_batch
))
# Start BatchEvaluator.
# batch_evaluator can be used between start/finish.
batch_evaluator
.
start
()
# A callback when backward.
# It is used for updating weight values vy calculated Gradient.
def
updater_callback
(
param
):
def
updater_callback
(
param
):
updater
.
update
(
param
)
updater
.
update
(
param
)
# forwardBackward is a shortcut for forward and backward.
# It is sometimes faster than invoke forward/backward separately,
# because in GradientMachine, it may be async.
m
.
forwardBackward
(
m
.
forwardBackward
(
converter
(
data_batch
),
outArgs
,
trainRol
e
,
updater_callback
)
converter
(
data_batch
),
outArgs
,
pass_typ
e
,
updater_callback
)
# Get cost. We use numpy to calculate total cost for this batch.
cost_vec
=
outArgs
.
getSlotValue
(
0
)
cost_vec
=
outArgs
.
getSlotValue
(
0
)
cost_vec
=
cost_vec
.
copyToNumpyMat
()
cost_vec
=
cost_vec
.
copyToNumpyMat
()
cost
=
cost_vec
.
sum
()
/
len
(
data_batch
)
cost
=
cost_vec
.
sum
()
/
len
(
data_batch
)
print
'Batch id'
,
batch_id
,
'with cost='
,
cost
# Make evaluator works.
m
.
eval
(
batch_evaluator
)
# Print logs.
print
'Pass id'
,
pass_id
,
'Batch id'
,
batch_id
,
'with cost='
,
\
cost
,
batch_evaluator
batch_evaluator
.
finish
()
# Finish batch.
# * will clear gradient.
# * ensure all values should be updated.
updater
.
finishBatch
(
cost
)
updater
.
finishBatch
(
cost
)
# testing stage. use test data set to test current network.
test_evaluator
.
start
()
test_data_generator
=
input_order_converter
(
read_from_mnist
(
test_file
))
for
data_batch
in
generator_to_batch
(
test_data_generator
,
128
):
# in testing stage, only forward is needed.
m
.
forward
(
converter
(
data_batch
),
outArgs
,
api
.
PASS_TEST
)
m
.
eval
(
test_evaluator
)
# print error rate for test data set
print
'Pass'
,
pass_id
,
' test evaluator: '
,
test_evaluator
test_evaluator
.
finish
()
updater
.
finishPass
()
updater
.
finishPass
()
m
.
finish
()
m
.
finish
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录