Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
bcf9f421
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bcf9f421
编写于
7月 12, 2017
作者:
武
武毅
提交者:
GitHub
7月 12, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2774 from typhoonzero/fix_newupdater
Fix new remote updater for go pserver
上级
a2e5f652
5a4f33df
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
33 addition
and
10 deletion
+33
-10
go/pserver/client/c/test/test_train.py
go/pserver/client/c/test/test_train.py
+1
-1
go/pserver/optimizer.go
go/pserver/optimizer.go
+8
-6
paddle/trainer/NewRemoteParameterUpdater.cpp
paddle/trainer/NewRemoteParameterUpdater.cpp
+19
-3
paddle/trainer/NewRemoteParameterUpdater.h
paddle/trainer/NewRemoteParameterUpdater.h
+2
-0
python/paddle/v2/optimizer.py
python/paddle/v2/optimizer.py
+2
-0
python/paddle/v2/trainer.py
python/paddle/v2/trainer.py
+1
-0
未找到文件。
go/pserver/client/c/test/test_train.py
浏览文件 @
bcf9f421
...
...
@@ -19,7 +19,7 @@ def main():
# create parameters
parameters
=
paddle
.
parameters
.
create
(
cost
)
# create optimizer
# create optimizer
of new remote updater to pserver
optimizer
=
paddle
.
optimizer
.
Momentum
(
momentum
=
0
)
#TODO(zhihong) : replace optimizer with new OptimizerConfig
...
...
go/pserver/optimizer.go
浏览文件 @
bcf9f421
...
...
@@ -41,22 +41,24 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
p
:=
paramWithConfigs
.
Param
c
:=
paramWithConfigs
.
Config
s
:=
State
paramBufferSize
:=
C
.
size_t
(
len
(
p
.
Content
)
/
C
.
sizeof_float
)
log
.
WithFields
(
log
.
Fields
{
"ElementType"
:
p
.
ElementType
,
"ParamSize"
:
len
(
p
.
Content
)
,
"ParamSize"
:
paramBufferSize
,
"ConfigSize"
:
len
(
c
),
"StateSize"
:
len
(
s
),
})
.
Info
(
"New Optimizer Created with config:"
)
var
cbuffer
unsafe
.
Pointer
cbuffer
=
C
.
malloc
(
C
.
size_t
(
len
(
p
.
Content
)))
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
C
.
size_t
(
len
(
p
.
Content
)))
cbuffer
=
C
.
malloc
(
paramBufferSize
)
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
paramBufferSize
)
var
cstate
unsafe
.
Pointer
if
len
(
s
)
!=
0
{
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
paramBufferSize
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
return
o
}
...
...
@@ -68,8 +70,8 @@ func (o *optimizer) GetWeights() []byte {
func
(
o
*
optimizer
)
GetStates
()
[]
byte
{
var
cbuffer
*
C
.
char
cbuffer
_l
en
:=
C
.
paddle_optimizer_get_state
(
o
.
opt
,
&
cbuffer
)
return
cArrayToSlice
(
unsafe
.
Pointer
(
cbuffer
),
int
(
cbuffer
_l
en
))
cbuffer
L
en
:=
C
.
paddle_optimizer_get_state
(
o
.
opt
,
&
cbuffer
)
return
cArrayToSlice
(
unsafe
.
Pointer
(
cbuffer
),
int
(
cbuffer
L
en
))
}
func
(
o
*
optimizer
)
UpdateParameter
(
g
Gradient
)
error
{
...
...
paddle/trainer/NewRemoteParameterUpdater.cpp
浏览文件 @
bcf9f421
...
...
@@ -22,7 +22,8 @@ DECLARE_string(save_dir);
namespace
paddle
{
NewRemoteParameterUpdater
::
NewRemoteParameterUpdater
(
const
OptimizationConfig
&
config
,
const
std
::
string
pserverSpec
)
:
parameterClient_
(
-
1
),
:
trainerConfig_
(
config
),
parameterClient_
(
-
1
),
newParameters_
(
nullptr
),
newGradients_
(
nullptr
),
pserverSpec_
(
pserverSpec
)
{}
...
...
@@ -51,7 +52,22 @@ void NewRemoteParameterUpdater::init(
LOG
(
INFO
)
<<
"paddle_begin_init_params start"
;
for
(
int
i
=
0
;
i
<
parameterSize
();
++
i
)
{
auto
paramConfig
=
parameters_
[
i
]
->
getConfig
();
std
::
string
bytes
=
paramConfig
.
SerializeAsString
();
LOG
(
INFO
)
<<
"old param config: "
<<
paramConfig
.
DebugString
();
// FIXME(typhoonzero): convert old paramConfig to optimizerConfig
OptimizerConfig
optimizeConfigV2
;
auto
sgdConfigV2
=
optimizeConfigV2
.
mutable_sgd
();
sgdConfigV2
->
set_momentum
(
paramConfig
.
momentum
());
sgdConfigV2
->
set_decay
(
paramConfig
.
decay_rate
());
optimizeConfigV2
.
set_lr_policy
(
paddle
::
OptimizerConfig
::
Const
);
auto
constlr
=
optimizeConfigV2
.
mutable_const_lr
();
constlr
->
set_learning_rate
(
paramConfig
.
learning_rate
());
if
(
trainerConfig_
.
algorithm
()
==
"sgd"
)
{
optimizeConfigV2
.
set_optimizer
(
paddle
::
OptimizerConfig
::
SGD
);
// FIXME: config all algorithms
}
else
{
optimizeConfigV2
.
set_optimizer
(
paddle
::
OptimizerConfig
::
SGD
);
}
std
::
string
bytes
=
optimizeConfigV2
.
SerializeAsString
();
const
char
*
array
=
bytes
.
data
();
int
size
=
(
int
)
bytes
.
size
();
paddle_init_param
(
...
...
@@ -83,4 +99,4 @@ void NewRemoteParameterUpdater::finishBatch(real cost) {
void
NewRemoteParameterUpdater
::
startPass
()
{}
bool
NewRemoteParameterUpdater
::
finishPass
()
{
return
true
;
}
}
}
// namespace paddle
paddle/trainer/NewRemoteParameterUpdater.h
浏览文件 @
bcf9f421
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include <functional>
#include <thread>
#include "OptimizerConfig.pb.h"
#include "ParameterUpdater.h"
#include "libpaddle_pserver_cclient.h"
#include "paddle/pserver/ParameterClient2.h"
...
...
@@ -101,6 +102,7 @@ private:
}
protected:
const
OptimizationConfig
&
trainerConfig_
;
/// internal parameter client object for exchanging data with pserver
paddle_pserver_client
parameterClient_
;
/// the parameters for new pserver client
...
...
python/paddle/v2/optimizer.py
浏览文件 @
bcf9f421
...
...
@@ -66,6 +66,8 @@ class Optimizer(object):
if use_sparse_remote_updater:
gradient_machine.prefetch(in_args)
parameter_updater.getParametersRemote()
:param pserver_spec: pserver location, eg: localhost:3000
:return: parameter_updater
"""
if
is_local
:
...
...
python/paddle/v2/trainer.py
浏览文件 @
bcf9f421
...
...
@@ -41,6 +41,7 @@ class SGD(object):
:type parameters: paddle.v2.parameters.Parameters
:param extra_layers: Some layers in the neural network graph are not
in the path of cost layer.
:param pserver_spec: pserver location, eg: localhost:3000
:type extra_layers: paddle.v2.config_base.Layer
"""
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录