Commit 91f13e48
Authored Mar 01, 2017 by jacquesqiao
Committed by GitHub on Mar 01, 2017
Merge pull request #1465 from reyoung/feature/tester
Paddle.V2.Trainer.test method complete.
Parents: b63d38d1, b9f8cc06
Showing 5 changed files with 107 additions and 73 deletions (+107 −73)
demo/mnist/api_train_v2.py          +13  −10
python/paddle/v2/dataset/mnist.py    +2   −2
python/paddle/v2/event.py            +9   −1
python/paddle/v2/topology.py        +25  −24
python/paddle/v2/trainer.py         +58  −36
demo/mnist/api_train_v2.py

@@ -20,26 +20,29 @@ def main():
     adam_optimizer = paddle.optimizer.Adam(learning_rate=0.01)
 
+    trainer = paddle.trainer.SGD(cost=cost,
+                                 parameters=parameters,
+                                 update_equation=adam_optimizer)
+
     def event_handler(event):
         if isinstance(event, paddle.event.EndIteration):
-            if event.batch_id % 100 == 0:
-                print "Pass %d, Batch %d, Cost %f, %s" % (
-                    event.pass_id, event.batch_id, event.cost, event.metrics)
+            if event.batch_id % 1000 == 0:
+                result = trainer.test(reader=paddle.reader.batched(
+                    paddle.dataset.mnist.test(), batch_size=256))
+                print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
+                    event.pass_id, event.batch_id, event.cost, event.metrics,
+                    result.metrics)
         else:
             pass
 
-    trainer = paddle.trainer.SGD(update_equation=adam_optimizer)
     trainer.train(
         reader=paddle.reader.batched(
             paddle.reader.shuffle(
                 paddle.dataset.mnist.train(), buf_size=8192),
             batch_size=32),
-        cost=cost,
-        parameters=parameters,
-        event_handler=event_handler,
-        reader_dict={images.name: 0,
-                     label.name: 1})
+        event_handler=event_handler)
 
 
 if __name__ == '__main__':
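Note: taken together, the demo now binds cost and parameters when the trainer is constructed and calls the new trainer.test() from inside the event handler; the explicit reader_dict is gone because the trainer can build a default mapping itself (see python/paddle/v2/trainer.py below). A condensed sketch of the resulting flow, assuming cost, parameters and the paddle imports are built earlier in api_train_v2.py exactly as in the rest of the demo:

# Condensed sketch of the post-change demo flow; `cost` and `parameters`
# are assumed to be defined earlier in api_train_v2.py (not shown here).
trainer = paddle.trainer.SGD(cost=cost,
                             parameters=parameters,
                             update_equation=paddle.optimizer.Adam(learning_rate=0.01))

def event_handler(event):
    if isinstance(event, paddle.event.EndIteration) and event.batch_id % 1000 == 0:
        # The new test() method runs one evaluation pass over the given reader
        # and returns an event.TestResult whose metrics are printed alongside
        # the training metrics.
        result = trainer.test(reader=paddle.reader.batched(
            paddle.dataset.mnist.test(), batch_size=256))
        print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % (
            event.pass_id, event.batch_id, event.cost, event.metrics,
            result.metrics)

trainer.train(
    reader=paddle.reader.batched(
        paddle.reader.shuffle(paddle.dataset.mnist.train(), buf_size=8192),
        batch_size=32),
    event_handler=event_handler)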
python/paddle/v2/dataset/mnist.py

@@ -9,9 +9,9 @@ __all__ = ['train', 'test']
 URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
 TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
-TEST_IMAGE_MD5 = '25e3cc63507ef6e98d5dc541e8672bb6'
+TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
 TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
-TEST_LABEL_MD5 = '4e9511fe019b2189026bd0421ba7b688'
+TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
 TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
 TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
 TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
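Note: the only change here replaces the MD5 constants for the MNIST test-set archives with the corrected values. As a generic illustration of how such a constant is checked against a cached download (this is not the downloader paddle.v2.dataset actually uses):

import hashlib

def md5_of(path, chunk_size=4096):
    # Generic illustration only: stream the file so large archives such as
    # the MNIST .gz files do not have to fit in memory at once.
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()

# A cached 't10k-images-idx3-ubyte.gz' should now hash to TEST_IMAGE_MD5,
# i.e. '9fb629c4189551a2d022fa330f9573f3'; a mismatch means a re-download.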
python/paddle/v2/event.py

@@ -11,7 +11,10 @@ There are:
 TODO(yuyang18): Complete it!
 """
 import py_paddle.swig_paddle as api
-__all__ = ['EndIteration', 'BeginIteration', 'BeginPass', 'EndPass']
+__all__ = [
+    'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult'
+]
 
 
 class WithMetric(object):

@@ -30,6 +33,11 @@ class WithMetric(object):
         return retv
 
 
+class TestResult(WithMetric):
+    def __init__(self, evaluator):
+        super(TestResult, self).__init__(evaluator)
+
+
 class BeginPass(object):
     """
     Event On One Pass Training Start.
python/paddle/v2/topology.py

@@ -21,6 +21,14 @@ import layer as v2_layer
 __all__ = ['Topology']
 
 
+def __bfs_travel__(callback, *layers):
+    for each_layer in layers:
+        __break__ = callback(each_layer)
+        if __break__:
+            return
+        __bfs_travel__(callback, *each_layer.__parent_layers__.values())
+
+
 class Topology(object):
     """
     Topology is used to store the information about all layers

@@ -46,21 +54,17 @@ class Topology(object):
         :param name:
         :return:
         """
-        result_layer = []
+        result_layer = [None]
 
-        def find_layer_by_name(layer, layer_name):
-            if len(result_layer) == 1:
-                return
-            elif layer.name == layer_name:
-                result_layer.append(layer)
-            else:
-                for parent_layer in layer.__parent_layers__.values():
-                    find_layer_by_name(parent_layer, layer_name)
+        def __impl__(l):
+            if l.name == name:
+                result_layer[0] = l
+                return True  # break
+            return False
 
-        for layer in self.layers:
-            find_layer_by_name(layer, name)
-
-        assert len(result_layer) == 1
+        __bfs_travel__(__impl__, *self.layers)
+        if result_layer[0] is None:
+            raise ValueError("No such layer %s" % name)
         return result_layer[0]
 
     def data_layers(self):

@@ -68,17 +72,13 @@ class Topology(object):
         get all data layer
         :return:
         """
-        data_layers = set()
+        data_layers = dict()
 
-        def find_data_layer(layer):
-            if isinstance(layer, v2_layer.DataLayerV2):
-                data_layers.add(layer)
-            for parent_layer in layer.__parent_layers__.values():
-                find_data_layer(parent_layer)
-
-        for layer in self.layers:
-            find_data_layer(layer)
+        def __impl__(l):
+            if isinstance(l, v2_layer.DataLayerV2):
+                data_layers[l.name] = l
 
+        __bfs_travel__(__impl__, *self.layers)
         return data_layers
 
     def data_type(self):

@@ -86,8 +86,9 @@ class Topology(object):
         get data_type from proto, such as:
         [('image', dense_vector(768)), ('label', integer_value(10))]
         """
-        return [(data_layer.name, data_layer.type)
-                for data_layer in self.data_layers()]
+        data_layers = self.data_layers()
+        return [(nm, data_layers[nm].type)
+                for nm in self.proto().input_layer_names]
 
     def __check_layer_type__(layer):
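Note: the new __bfs_travel__ helper replaces the two hand-written recursive walks with a single traversal that visits each layer and then its __parent_layers__, stopping as soon as the callback returns True. A self-contained sketch of the same pattern, where DummyLayer is a hypothetical stand-in so the snippet runs outside Paddle:

# Stand-alone illustration of the traversal used by __bfs_travel__;
# DummyLayer is a hypothetical stand-in for a Paddle v2 layer object.
class DummyLayer(object):
    def __init__(self, name, *parents):
        self.name = name
        self.__parent_layers__ = {p.name: p for p in parents}

def bfs_travel(callback, *layers):
    # Mirrors __bfs_travel__ in this diff: walk each layer, then its parents,
    # and stop descending once the callback signals a match by returning True.
    for each_layer in layers:
        if callback(each_layer):
            return
        bfs_travel(callback, *each_layer.__parent_layers__.values())

data = DummyLayer('pixel')
hidden = DummyLayer('hidden', data)
cost = DummyLayer('cost', hidden)

found = [None]

def find_pixel(layer):
    if layer.name == 'pixel':
        found[0] = layer
        return True   # break out of the traversal, as get_layer() does
    return False

bfs_travel(find_pixel, cost)
assert found[0] is data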
python/paddle/v2/trainer.py

@@ -42,25 +42,35 @@ class ITrainer(object):
 
 
 class SGD(ITrainer):
-    def __init__(self, update_equation):
+    def __init__(self, cost, parameters, update_equation):
         """
         Simple SGD Trainer.
 
         :param update_equation: The optimizer object.
         :type update_equation: v2_optimizer.Optimizer
         """
+        if not isinstance(parameters, v2_parameters.Parameters):
+            raise TypeError('parameters should be parameters')
+
         if not isinstance(update_equation, v2_optimizer.Optimizer):
-            raise ValueError("update equation parameter must be "
+            raise TypeError("update equation parameter must be "
                              "paddle.v2.optimizer.Optimizer")
+        topology = Topology(cost)
         self.__optimizer__ = update_equation
+        self.__topology__ = topology
+        self.__parameters__ = parameters
+        self.__topology_in_proto__ = topology.proto()
+        self.__data_types__ = topology.data_type()
+        gm = api.GradientMachine.createFromConfigProto(
+            self.__topology_in_proto__, api.CREATE_MODE_NORMAL,
+            self.__optimizer__.enable_types())
+        assert isinstance(gm, api.GradientMachine)
+        parameters.append_gradient_machine(gm)
+        self.__gradient_machine__ = gm
+        self.__gradient_machine__.randParameters()
 
-    def train(self,
-              reader,
-              cost,
-              parameters,
-              num_passes=1,
-              event_handler=None,
-              reader_dict=None):
+    def train(self, reader, num_passes=1, event_handler=None,
+              reader_dict=None):
         """
         Training method. Will train num_passes of input data.

@@ -76,27 +86,22 @@ class SGD(ITrainer):
         if event_handler is None:
             event_handler = default_event_handler
 
-        topology = Topology(cost)
+        if reader_dict is None:
+            reader_dict = self.default_reader_dict()
 
         __check_train_args__(**locals())
 
-        gm = api.GradientMachine.createFromConfigProto(
-            topology.proto(), api.CREATE_MODE_NORMAL,
-            self.__optimizer__.enable_types())
-        assert isinstance(gm, api.GradientMachine)
-        parameters.append_gradient_machine(gm)
-        gm.randParameters()
-
         updater = self.__optimizer__.create_local_updater()
-        updater.init(gm)
+        updater.init(self.__gradient_machine__)
 
-        gm.start()
-        batch_evaluator = gm.makeEvaluator()
+        self.__gradient_machine__.start()
+        batch_evaluator = self.__gradient_machine__.makeEvaluator()
         assert isinstance(batch_evaluator, api.Evaluator)
-        pass_evaluator = gm.makeEvaluator()
+        pass_evaluator = self.__gradient_machine__.makeEvaluator()
         assert isinstance(pass_evaluator, api.Evaluator)
         out_args = api.Arguments.createArguments(0)
-        feeder = DataFeeder(topology.data_type(), reader_dict)
+        feeder = DataFeeder(self.__data_types__, reader_dict)
 
         for pass_id in xrange(num_passes):
             event_handler(v2_event.BeginPass(pass_id))

@@ -104,16 +109,18 @@ class SGD(ITrainer):
             updater.startPass()
             for batch_id, data_batch in enumerate(reader()):
                 pass_type = updater.startBatch(len(data_batch))
-                gm.forwardBackward(feeder(data_batch), out_args, pass_type)
+                self.__gradient_machine__.forwardBackward(
+                    feeder(data_batch), out_args, pass_type)
                 batch_evaluator.start()
                 event_handler(
                     v2_event.BeginIteration(
                         pass_id=pass_id, batch_id=batch_id))
                 pass_type = updater.startBatch(len(data_batch))
-                gm.forwardBackward(feeder(data_batch), out_args, pass_type)
-                gm.eval(pass_evaluator)
-                gm.eval(batch_evaluator)
-                for each_param in gm.getParameters():
+                self.__gradient_machine__.forwardBackward(
+                    feeder(data_batch), out_args, pass_type)
+                self.__gradient_machine__.eval(pass_evaluator)
+                self.__gradient_machine__.eval(batch_evaluator)
+                for each_param in self.__gradient_machine__.getParameters():
                     updater.update(each_param)
                 # Get cost. We use numpy to calculate total cost for this batch.
                 cost_vec = out_args.getSlotValue(0)

@@ -131,22 +138,37 @@ class SGD(ITrainer):
             updater.finishPass()
             pass_evaluator.finish()
             event_handler(v2_event.EndPass(pass_id, evaluator=pass_evaluator))
-        gm.finish()
+        self.__gradient_machine__.finish()
+
+    def default_reader_dict(self):
+        reader_dict = dict()
+        for i, tp in enumerate(self.__data_types__):
+            reader_dict[tp[0]] = i
+        return reader_dict
+
+    def test(self, reader, reader_dict=None):
+        if reader_dict is None:
+            reader_dict = self.default_reader_dict()
+
+        feeder = DataFeeder(self.__data_types__, reader_dict)
+        evaluator = self.__gradient_machine__.makeEvaluator()
+        out_args = api.Arguments.createArguments(0)
+        evaluator.start()
+        for data_batch in reader():
+            self.__gradient_machine__.forward(
+                feeder(data_batch), out_args, api.PASS_TEST)
+            self.__gradient_machine__.eval(evaluator)
+
+        evaluator.finish()
+        return v2_event.TestResult(evaluator=evaluator)
 
 
-def __check_train_args__(reader, topology, parameters, event_handler,
-                         **kwargs):
+def __check_train_args__(reader, event_handler, **kwargs):
     """
     Check train function's argument types
     """
     if not callable(reader) or not isinstance(reader(), collections.Iterator):
         raise TypeError('train_data_reader should be a function, '
                         'which can return a iterator')
 
-    if not isinstance(topology, Topology):
-        raise TypeError('topology should be a model config')
-
-    if not isinstance(parameters, v2_parameters.Parameters):
-        raise TypeError('parameters should be a parameter pool')
-
     if not callable(event_handler):
         raise TypeError('event handler should be a function')
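Note: default_reader_dict is what lets the demo drop its explicit reader_dict argument: it maps each entry of self.__data_types__, in order, to the position of that field in the tuples a reader yields. A minimal sketch of the same mapping, using the example data_type() output quoted in the topology.py docstring above (type objects replaced by strings so the snippet stands alone):

# Minimal sketch of default_reader_dict's mapping logic outside the class.
data_types = [('image', 'dense_vector(768)'), ('label', 'integer_value(10)')]

reader_dict = dict()
for i, tp in enumerate(data_types):
    reader_dict[tp[0]] = i

assert reader_dict == {'image': 0, 'label': 1}
# i.e. the first element of every sample feeds the 'image' data layer and the
# second feeds 'label', which is what the old demo spelled out explicitly as
# reader_dict={images.name: 0, label.name: 1}.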