Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f9ea5864
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f9ea5864
编写于
2月 16, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add get/set method
上级
9646f2d3
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
16 addition
and
4 deletion
+16
-4
demo/mnist/api_train_v2.py
demo/mnist/api_train_v2.py
+3
-3
python/paddle/v2/parameters.py
python/paddle/v2/parameters.py
+13
-1
未找到文件。
demo/mnist/api_train_v2.py
浏览文件 @
f9ea5864
...
@@ -28,16 +28,16 @@ def main():
...
@@ -28,16 +28,16 @@ def main():
topology
=
parse_network_config
(
network_config
)
topology
=
parse_network_config
(
network_config
)
parameters
=
paddle
.
parameters
.
create
(
topology
)
parameters
=
paddle
.
parameters
.
create
(
topology
)
for
param_name
in
parameters
.
keys
():
for
param_name
in
parameters
.
keys
():
array
=
parameters
[
param_name
]
array
=
parameters
.
get
(
param_name
)
array
[:]
=
numpy
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
array
.
shape
)
array
[:]
=
numpy
.
random
.
uniform
(
low
=-
1.0
,
high
=
1.0
,
size
=
array
.
shape
)
parameters
[
param_name
]
=
array
parameters
.
set
(
parameter_name
=
param_name
,
value
=
array
)
adam_optimizer
=
paddle
.
optimizer
.
Optimizer
(
adam_optimizer
=
paddle
.
optimizer
.
Optimizer
(
learning_rate
=
0.01
,
learning_method
=
AdamOptimizer
())
learning_rate
=
0.01
,
learning_method
=
AdamOptimizer
())
def
event_handler
(
event
):
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
para
=
parameters
[
'___fc_layer_2__.w0'
]
para
=
parameters
.
get
(
'___fc_layer_2__.w0'
)
print
"Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f"
%
(
print
"Pass %d, Batch %d, Cost %f, Weight Mean Of Fc 2 is %f"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
para
.
mean
())
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
para
.
mean
())
...
...
python/paddle/v2/parameters.py
浏览文件 @
f9ea5864
...
@@ -26,6 +26,10 @@ def create(*topologies):
...
@@ -26,6 +26,10 @@ def create(*topologies):
class
Parameters
(
object
):
class
Parameters
(
object
):
"""
The parameters
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
__param_conf__
=
dict
()
self
.
__param_conf__
=
dict
()
self
.
__gradient_machines__
=
[]
self
.
__gradient_machines__
=
[]
...
@@ -66,7 +70,8 @@ class Parameters(object):
...
@@ -66,7 +70,8 @@ class Parameters(object):
assert
isinstance
(
param
,
api
.
Parameter
)
assert
isinstance
(
param
,
api
.
Parameter
)
val
=
param
.
getBuf
(
api
.
PARAMETER_VALUE
)
val
=
param
.
getBuf
(
api
.
PARAMETER_VALUE
)
assert
isinstance
(
val
,
api
.
Vector
)
assert
isinstance
(
val
,
api
.
Vector
)
return
val
.
copyToNumpyArray
().
reshape
(
shape
=
shape
)
val
=
val
.
copyToNumpyArray
()
return
val
# else continue
# else continue
raise
RuntimeError
(
"Unexpected branch"
)
raise
RuntimeError
(
"Unexpected branch"
)
...
@@ -96,6 +101,12 @@ class Parameters(object):
...
@@ -96,6 +101,12 @@ class Parameters(object):
__copy_parameter_to_gradient_machine__
(
each_gradient_machine
,
__copy_parameter_to_gradient_machine__
(
each_gradient_machine
,
key
,
value
)
key
,
value
)
def
get
(
self
,
parameter_name
):
return
self
.
__getitem__
(
key
=
parameter_name
)
def
set
(
self
,
parameter_name
,
value
):
self
.
__setitem__
(
key
=
parameter_name
,
value
=
value
)
def
append_gradient_machine
(
self
,
gradient_machine
):
def
append_gradient_machine
(
self
,
gradient_machine
):
if
not
isinstance
(
gradient_machine
,
api
.
GradientMachine
):
if
not
isinstance
(
gradient_machine
,
api
.
GradientMachine
):
raise
ValueError
(
"gradient_machine should be api.GradientMachine"
)
raise
ValueError
(
"gradient_machine should be api.GradientMachine"
)
...
@@ -108,6 +119,7 @@ class Parameters(object):
...
@@ -108,6 +119,7 @@ class Parameters(object):
except
ValueError
:
except
ValueError
:
# If no such parameter in gradient machine, then don't copy
# If no such parameter in gradient machine, then don't copy
pass
pass
self
.
__gradient_machines__
.
append
(
gradient_machine
)
def
__get_parameter_in_gradient_machine__
(
gradient_machine
,
name
):
def
__get_parameter_in_gradient_machine__
(
gradient_machine
,
name
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录