Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
dc7f2031
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dc7f2031
编写于
2月 13, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add comments for functions
上级
ce49124d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
50 addition
and
13 deletion
+50
-13
python/paddle/v2/trainer.py
python/paddle/v2/trainer.py
+50
-13
未找到文件。
python/paddle/v2/trainer.py
浏览文件 @
dc7f2031
...
...
@@ -16,6 +16,10 @@ class BaseEvent(object):
class
CompleteTrainOneBatch
(
BaseEvent
):
"""
Event On One Batch Training Complete.
"""
def
__init__
(
self
,
pass_id
,
batch_id
,
cost
):
self
.
pass_id
=
pass_id
self
.
batch_id
=
batch_id
...
...
@@ -38,6 +42,11 @@ class ITrainer(object):
class
SGDTrainer
(
ITrainer
):
def
__init__
(
self
,
update_equation
):
"""
Simple SGD Trainer.
:param update_equation: Maybe we should give a DSL for update equation?
"""
if
not
isinstance
(
update_equation
,
paddle
.
v2
.
optimizer
.
Optimizer
):
raise
ValueError
()
...
...
@@ -52,6 +61,21 @@ class SGDTrainer(ITrainer):
event_handler
=
None
,
batch_size
=
32
,
data_types
=
None
):
"""
Training method. Will train num_passes of input data.
:param train_data_reader:
:param topology: Network Topology, a protobuf ModelConfig message.
:param parameters: The parameter pools.
:param num_passes: The total train passes.
:param test_data_reader:
:param event_handler: Event handler. A method will be invoked when event
occurred.
:type event_handler: (BaseEvent) => None
:param batch_size: Not important, will be removed after data refactor.
:param data_types: Not important, will be removed after data refactor.
:return:
"""
if
event_handler
is
None
:
event_handler
=
default_event_handler
...
...
@@ -66,6 +90,9 @@ class SGDTrainer(ITrainer):
assert
isinstance
(
updater
,
api
.
ParameterUpdater
)
updater
.
init
(
gm
)
gm
.
start
()
out_args
=
api
.
Arguments
.
createArguments
(
0
)
data_types_lists
=
[]
for
each
in
topology
.
input_layer_names
:
if
each
not
in
data_types
:
...
...
@@ -74,22 +101,11 @@ class SGDTrainer(ITrainer):
converter
=
DataProviderConverter
(
input_types
=
data_types_lists
)
def
input_reorder
(
func
):
for
item
in
func
():
retv
=
[]
for
__layer_name__
in
topology
.
input_layer_names
:
retv
.
append
(
item
[
__layer_name__
])
yield
retv
gm
.
start
()
out_args
=
api
.
Arguments
.
createArguments
(
0
)
for
pass_id
in
xrange
(
num_passes
):
updater
.
startPass
()
for
batch_id
,
data_batch
in
enumerate
(
__generator_to_batch__
(
input_reorder
(
train_data_reader
),
batch_size
=
batch_size
)):
__data_reader_to_batch__
(
train_data_reader
,
batch_size
,
topology
)):
pass_type
=
updater
.
startBatch
(
len
(
data_batch
))
gm
.
forwardBackward
(
converter
(
data_batch
),
out_args
,
pass_type
)
for
each_param
in
gm
.
getParameters
():
...
...
@@ -108,7 +124,25 @@ class SGDTrainer(ITrainer):
gm
.
finish
()
def
__data_reader_to_batch__
(
reader
,
batch_size
,
topology
):
"""
This function is not important, and will be removed when data refactored.
"""
def
input_reorder
(
func
):
for
item
in
func
():
retv
=
[]
for
__layer_name__
in
topology
.
input_layer_names
:
retv
.
append
(
item
[
__layer_name__
])
yield
retv
return
__generator_to_batch__
(
input_reorder
(
reader
),
batch_size
=
batch_size
)
def
__generator_to_batch__
(
generator
,
batch_size
):
"""
This function is not important, and will be removed when data refactored.
"""
ret_val
=
list
()
for
each_item
in
generator
:
ret_val
.
append
(
each_item
)
...
...
@@ -139,6 +173,9 @@ def __copy_parameter_from_pool__(gm, pool):
def
__check_train_args__
(
train_data_reader
,
topology
,
parameters
,
test_data_reader
,
event_handler
,
**
kwargs
):
"""
Check train function's argument types
"""
if
not
callable
(
train_data_reader
)
or
not
isinstance
(
train_data_reader
(),
collections
.
Iterator
):
raise
ValueError
(
'train_data_reader should be a function, '
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录