Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
82103508
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
82103508
编写于
4月 15, 2017
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add getParametersRemote for ParameterUpdater in api
上级
64bfd814
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
25 addition
and
13 deletion
+25
-13
paddle/api/PaddleAPI.h
paddle/api/PaddleAPI.h
+7
-0
paddle/api/ParameterUpdater.cpp
paddle/api/ParameterUpdater.cpp
+4
-0
python/paddle/v2/topology.py
python/paddle/v2/topology.py
+4
-2
python/paddle/v2/trainer.py
python/paddle/v2/trainer.py
+10
-11
未找到文件。
paddle/api/PaddleAPI.h
浏览文件 @
82103508
...
...
@@ -859,6 +859,13 @@ public:
*/
void
update
(
Parameter
*
param
);
/**
* @breif only get required sparse rows by default.
* @param fullSize: get full matrix parameter if *fullSize* set
* @param apply: get PARAMETER_APPLY on pserver if *apply* set
*/
void
getParametersRemote
(
bool
fullSize
=
false
,
bool
apply
=
false
);
/**
* @brief restore the average parameter.
* @note It is only used in AverageOptimizer. Restore will get the current
...
...
paddle/api/ParameterUpdater.cpp
浏览文件 @
82103508
...
...
@@ -72,6 +72,10 @@ void ParameterUpdater::update(Parameter *param) {
m
->
updater
->
update
(
paddleParam
);
}
void
ParameterUpdater
::
getParametersRemote
(
bool
fullSize
,
bool
apply
)
{
m
->
updater
->
getParametersRemote
(
fullSize
,
apply
);
}
void
ParameterUpdater
::
restore
()
{
m
->
updater
->
restore
();
}
void
ParameterUpdater
::
apply
()
{
m
->
updater
->
apply
();
}
...
...
python/paddle/v2/topology.py
浏览文件 @
82103508
...
...
@@ -78,10 +78,12 @@ class Topology(object):
check if any parameter require to use sparse_update
:return:
"""
use_sparse
=
False
for
parameter
in
self
.
__model_config__
.
parameters
:
if
parameter
.
sparse_update
or
parameter
.
sparse_remote_update
:
return
True
return
False
use_sparse
=
True
break
return
use_sparse
def
proto
(
self
):
return
self
.
__model_config__
...
...
python/paddle/v2/trainer.py
浏览文件 @
82103508
...
...
@@ -65,7 +65,6 @@ class SGD(object):
self
.
__use_sparse_updater__
=
self
.
__topology__
.
use_sparse_updater
()
# # In local mode, disable sparse_remote_update.
if
is_local
:
self
.
__use_sparse_updater__
=
False
for
param
in
self
.
__topology_in_proto__
.
parameters
:
if
param
.
sparse_remote_update
:
param
.
sparse_remote_update
=
False
...
...
@@ -100,11 +99,11 @@ class SGD(object):
__check_train_args__
(
**
locals
())
if
self
.
__is_local__
:
updater
=
self
.
__optimizer__
.
create_local_updater
()
parameter_
updater
=
self
.
__optimizer__
.
create_local_updater
()
else
:
updater
=
self
.
__optimizer__
.
create_remote_updater
(
parameter_
updater
=
self
.
__optimizer__
.
create_remote_updater
(
num_passes
,
self
.
__use_sparse_updater__
)
updater
.
init
(
self
.
__gradient_machine__
)
parameter_
updater
.
init
(
self
.
__gradient_machine__
)
self
.
__gradient_machine__
.
start
()
batch_evaluator
=
self
.
__gradient_machine__
.
makeEvaluator
()
...
...
@@ -116,26 +115,26 @@ class SGD(object):
for
pass_id
in
xrange
(
num_passes
):
event_handler
(
v2_event
.
BeginPass
(
pass_id
))
pass_evaluator
.
start
()
updater
.
startPass
()
parameter_
updater
.
startPass
()
for
batch_id
,
data_batch
in
enumerate
(
reader
()):
batch_evaluator
.
start
()
event_handler
(
v2_event
.
BeginIteration
(
pass_id
=
pass_id
,
batch_id
=
batch_id
))
pass_type
=
updater
.
startBatch
(
len
(
data_batch
))
if
self
.
__use_sparse_updater__
:
pass_type
=
parameter_
updater
.
startBatch
(
len
(
data_batch
))
if
self
.
__use_sparse_updater__
and
not
self
.
__is_local__
:
self
.
__gradient_machine__
.
prefetch
(
feeder
(
data_batch
))
updater
.
getParametersRemote
()
parameter_
updater
.
getParametersRemote
()
self
.
__gradient_machine__
.
forwardBackward
(
feeder
(
data_batch
),
out_args
,
pass_type
)
self
.
__gradient_machine__
.
eval
(
pass_evaluator
)
self
.
__gradient_machine__
.
eval
(
batch_evaluator
)
for
each_param
in
self
.
__gradient_machine__
.
getNonStaticParameters
(
):
updater
.
update
(
each_param
)
parameter_
updater
.
update
(
each_param
)
cost_sum
=
out_args
.
sum
()
cost
=
cost_sum
/
len
(
data_batch
)
updater
.
finishBatch
(
cost
)
parameter_
updater
.
finishBatch
(
cost
)
batch_evaluator
.
finish
()
event_handler
(
v2_event
.
EndIteration
(
...
...
@@ -144,7 +143,7 @@ class SGD(object):
cost
=
cost
,
evaluator
=
batch_evaluator
))
updater
.
finishPass
()
parameter_
updater
.
finishPass
()
pass_evaluator
.
finish
()
event_handler
(
v2_event
.
EndPass
(
pass_id
,
evaluator
=
pass_evaluator
))
self
.
__gradient_machine__
.
finish
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录