Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8dc4c053
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8dc4c053
编写于
2月 16, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add comments
上级
f9ea5864
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
138 addition
and
7 deletion
+138
-7
python/paddle/v2/event.py
python/paddle/v2/event.py
+3
-0
python/paddle/v2/parameters.py
python/paddle/v2/parameters.py
+111
-6
python/paddle/v2/trainer.py
python/paddle/v2/trainer.py
+24
-1
未找到文件。
python/paddle/v2/event.py
浏览文件 @
8dc4c053
"""
All training events.
"""
__all__
=
[
'EndIteration'
]
...
...
python/paddle/v2/parameters.py
浏览文件 @
8dc4c053
...
...
@@ -21,13 +21,29 @@ def create(*topologies):
'create must pass a topologies which type is ModelConfig'
)
for
param
in
topo
.
parameters
:
pool
.
append_config
(
param
)
pool
.
__append_config__
(
param
)
return
pool
class
Parameters
(
object
):
"""
The parameters
Parameters is a dictionary contains Paddle's parameter. The key of
Parameters is the name of parameter. The value of Parameters is a plain
:code:`numpy.ndarry` .
Basically usage is
.. code-block:: python
data = paddle.layers.data(...)
...
out = paddle.layers.fc(...)
parameters = paddle.parameters.create(out)
parameter_names = parameters.names()
fc_mat = parameters.get('fc')
print fc_mat
"""
def
__init__
(
self
):
...
...
@@ -35,7 +51,16 @@ class Parameters(object):
self
.
__gradient_machines__
=
[]
self
.
__tmp_params__
=
[]
def
append_config
(
self
,
param_conf
):
def
__append_config__
(
self
,
param_conf
):
"""
Append a parameter configuration. It used to initialize Parameters and
should be invoked only in paddle.parameters.create
:param param_conf: The parameter configuration in protobuf
:type param_conf: ParameterConfig
:return: Nothing
"""
if
not
isinstance
(
param_conf
,
ParameterConfig
):
raise
ValueError
(
"param_conf must be paddle.proto.ParameterConfig"
)
...
...
@@ -45,18 +70,55 @@ class Parameters(object):
self
.
__param_conf__
[
param_conf
.
name
]
=
param_conf
def
keys
(
self
):
"""
keys are the names of each parameter.
:return: list of parameter name
:rtype: list
"""
return
self
.
__param_conf__
.
keys
()
def
names
(
self
):
"""
names of each parameter.
:return: list of parameter name
:rtype: list
"""
return
self
.
keys
()
def
has_key
(
self
,
key
):
"""
has_key return true if there are such parameter name == key
:param key: Parameter name
:type key: basestring
:return: True if contains such key
"""
return
key
in
self
.
__param_conf__
.
keys
()
def
__iter__
(
self
):
"""
Return an iterator of parameter name. It is used by `for loop`
or `in` operator.
.. code-block:: python
parameters = paddle.parameters.create(...)
if "fc_param" in parameters:
print 'OK'
:return: an iterator of parameter name
:rtype: iterator
"""
return
iter
(
self
.
__param_conf__
)
def
__getitem__
(
self
,
key
):
"""
Get parameter by parameter name. It uses Python dict syntax.
:note: It will always copy the parameter from C++ side.
:param key: Parameter name
:type key: basestring
:return: parameter value
:rtype: np.ndarray
"""
shape
=
self
.
get_shape
(
key
)
if
len
(
self
.
__gradient_machines__
)
==
0
:
...
...
@@ -77,20 +139,37 @@ class Parameters(object):
raise
RuntimeError
(
"Unexpected branch"
)
def
get_shape
(
self
,
key
):
"""
get shape of the parameter.
:param key: parameter name
:type key: basestring
:return: parameter's shape
:rtype: tuple
"""
if
not
isinstance
(
key
,
basestring
):
raise
ValueError
(
"parameter name should be string"
)
if
not
self
.
has_key
(
key
):
raise
ValueError
(
"No such parameter %s"
%
key
)
conf
=
self
.
__param_conf__
[
key
]
return
map
(
int
,
conf
.
dims
)
return
tuple
(
map
(
int
,
conf
.
dims
)
)
def
__setitem__
(
self
,
key
,
value
):
"""
Set parameter by parameter name & value. It use Python dict syntax.
:note: It will always copy the parameter to C++ side.
:param key: Parameter name
:type key: basestring
:param value: Parameter matrix.
:type value: np.ndarray
:return: Nothing
"""
if
not
isinstance
(
value
,
np
.
ndarray
):
raise
ValueError
(
"Must return ndarray"
)
value
=
value
.
astype
(
dtype
=
np
.
float32
)
shape
=
self
.
get_shape
(
key
)
if
not
reduce
(
lambda
a
,
b
:
a
and
b
,
map
(
lambda
x
:
x
[
0
]
==
x
[
1
],
zip
(
value
.
shape
,
shape
))):
if
value
.
shape
!=
shape
:
raise
ValueError
(
"Value shape mismatch, expect %s, should %s"
%
(
shape
,
value
.
shape
))
...
...
@@ -102,12 +181,38 @@ class Parameters(object):
key
,
value
)
def
get
(
self
,
parameter_name
):
"""
Get parameter by parameter name.
:note: It will always copy the parameter from C++ side.
:param parameter_name: parameter name
:type parameter_name: basestring
:return: The parameter matrix.
:rtype: np.ndarray
"""
return
self
.
__getitem__
(
key
=
parameter_name
)
def
set
(
self
,
parameter_name
,
value
):
"""
Set parameter by parameter name & matrix.
:param parameter_name: parameter name
:type parameter_name: basestring
:param value: parameter matrix
:type value: np.ndarray
:return: Nothing.
"""
self
.
__setitem__
(
key
=
parameter_name
,
value
=
value
)
def
append_gradient_machine
(
self
,
gradient_machine
):
"""
append gradient machine to parameters. This method is used internally in
Trainer.train.
:param gradient_machine: Paddle C++ GradientMachine object.
:type gradient_machine: api.GradientMachine
:return:
"""
if
not
isinstance
(
gradient_machine
,
api
.
GradientMachine
):
raise
ValueError
(
"gradient_machine should be api.GradientMachine"
)
...
...
python/paddle/v2/trainer.py
浏览文件 @
8dc4c053
...
...
@@ -12,16 +12,38 @@ __all__ = ['ITrainer', 'SGD']
def
default_event_handler
(
event
):
"""
Default event handler. It will print some log and save mode.
TODO(yuyang18): Complete it!
:param event:
:return:
"""
pass
class
ITrainer
(
object
):
"""
The interface of Trainer. The only exposed method is `train`.
"""
def
train
(
self
,
train_data_reader
,
topology
,
parameters
,
test_data_reader
=
None
,
event_handler
=
None
):
"""
train method.
:param train_data_reader:
:param topology:
:param parameters:
:param test_data_reader:
:param event_handler:
:return:
"""
raise
NotImplementedError
()
...
...
@@ -30,7 +52,8 @@ class SGD(ITrainer):
"""
Simple SGD Trainer.
:param update_equation: Maybe we should give a DSL for update equation?
:param update_equation: The optimizer object.
:type update_equation: v2_optimizer.Optimizer
"""
if
not
isinstance
(
update_equation
,
v2_optimizer
.
Optimizer
):
raise
ValueError
(
"update equation parameter must be "
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录