Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
6802b65c
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6802b65c
编写于
4月 14, 2017
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
init support remote updater
上级
b25c5124
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
32 addition
and
9 deletion
+32
-9
paddle/api/PaddleAPI.h
paddle/api/PaddleAPI.h
+1
-0
python/paddle/v2/topology.py
python/paddle/v2/topology.py
+10
-0
python/paddle/v2/trainer.py
python/paddle/v2/trainer.py
+21
-9
未找到文件。
paddle/api/PaddleAPI.h
浏览文件 @
6802b65c
...
...
@@ -469,6 +469,7 @@ private:
enum
GradientMatchineCreateMode
{
CREATE_MODE_NORMAL
=
0
,
CREATE_MODE_SGD_SPARSE_CPU_TRAINING
=
3
,
CREATE_MODE_TESTING
=
4
};
...
...
python/paddle/v2/topology.py
浏览文件 @
6802b65c
...
...
@@ -73,6 +73,16 @@ class Topology(object):
assert
isinstance
(
self
.
__model_config__
,
ModelConfig
)
def
use_sparse_updater
(
self
):
"""
check if any parameter require to use sparse_update
:return:
"""
for
parameter
in
self
.
__model_config__
.
parameters
:
if
parameter
.
sparse_update
or
parameter
.
sparse_remote_update
:
return
True
return
False
def
proto
(
self
):
return
self
.
__model_config__
...
...
python/paddle/v2/trainer.py
浏览文件 @
6802b65c
...
...
@@ -42,7 +42,7 @@ class SGD(object):
:type extra_layers: paddle.v2.config_base.Layer
"""
def
__init__
(
self
,
cost
,
parameters
,
update_equation
,
extra_layers
=
None
):
def
__init__
(
self
,
cost
,
parameters
,
update_equation
,
extra_layers
=
None
,
is_local
=
True
):
if
not
isinstance
(
parameters
,
v2_parameters
.
Parameters
):
raise
TypeError
(
'parameters should be parameters'
)
...
...
@@ -55,15 +55,21 @@ class SGD(object):
self
.
__topology__
=
topology
self
.
__parameters__
=
parameters
self
.
__topology_in_proto__
=
topology
.
proto
()
self
.
__is_local__
=
is_local
# In local mode, disable sparse_remote_update.
self
.
__use_sparse_updater__
=
self
.
__topology__
.
use_sparse_updater
()
# # In local mode, disable sparse_remote_update.
if
is_local
:
self
.
__use_sparse_updater__
=
False
for
param
in
self
.
__topology_in_proto__
.
parameters
:
if
param
.
sparse_remote_update
:
param
.
sparse_remote_update
=
False
self
.
__gm_create_mode__
=
api
.
CREATE_MODE_NORMAL
if
not
\
self
.
__use_sparse_updater__
else
api
.
CREATE_MODE_SGD_SPARSE_CPU_TRAINING
self
.
__data_types__
=
topology
.
data_type
()
gm
=
api
.
GradientMachine
.
createFromConfigProto
(
self
.
__topology_in_proto__
,
api
.
CREATE_MODE_NORMAL
,
self
.
__topology_in_proto__
,
self
.
__gm_create_mode__
,
self
.
__optimizer__
.
enable_types
())
assert
isinstance
(
gm
,
api
.
GradientMachine
)
self
.
__gradient_machine__
=
gm
...
...
@@ -88,7 +94,10 @@ class SGD(object):
event_handler
=
default_event_handler
__check_train_args__
(
**
locals
())
if
self
.
__is_local__
:
updater
=
self
.
__optimizer__
.
create_local_updater
()
else
:
updater
=
self
.
__optimizer__
.
create_remote_updater
(
num_passes
)
updater
.
init
(
self
.
__gradient_machine__
)
self
.
__gradient_machine__
.
start
()
...
...
@@ -108,6 +117,9 @@ class SGD(object):
v2_event
.
BeginIteration
(
pass_id
=
pass_id
,
batch_id
=
batch_id
))
pass_type
=
updater
.
startBatch
(
len
(
data_batch
))
if
self
.
__use_sparse_updater__
:
self
.
__gradient_machine__
.
prefetch
(
feeder
(
data_batch
))
updater
.
getParametersRemote
()
self
.
__gradient_machine__
.
forwardBackward
(
feeder
(
data_batch
),
out_args
,
pass_type
)
self
.
__gradient_machine__
.
eval
(
pass_evaluator
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录