Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
15f021a9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
15f021a9
编写于
7月 11, 2017
作者:
D
dzhwinter
提交者:
GitHub
7月 11, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2716 from dzhwinter/save_state
Pserver Save state
上级
8c615e8f
e8296ff2
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
146 addition
and
31 deletion
+146
-31
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+6
-2
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+13
-0
go/pserver/optimizer.go
go/pserver/optimizer.go
+15
-3
go/pserver/optimizer_test.go
go/pserver/optimizer_test.go
+1
-1
go/pserver/service.go
go/pserver/service.go
+99
-11
go/pserver/service_test.go
go/pserver/service_test.go
+12
-14
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
15f021a9
...
@@ -20,6 +20,8 @@ func main() {
...
@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd"
)
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
600
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
flag
.
Parse
()
...
@@ -31,18 +33,20 @@ func main() {
...
@@ -31,18 +33,20 @@ func main() {
log
.
SetLevel
(
level
)
log
.
SetLevel
(
level
)
var
idx
int
var
idx
int
var
cp
pserver
.
Checkpoint
var
e
*
pserver
.
EtcdClient
if
*
index
>=
0
{
if
*
index
>=
0
{
idx
=
*
index
idx
=
*
index
}
else
{
}
else
{
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
e
:
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
idx
,
err
=
e
.
Register
()
idx
,
err
=
e
.
Register
()
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
}
}
s
,
err
:=
pserver
.
NewService
(
idx
)
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
...
...
go/pserver/etcd_client.go
浏览文件 @
15f021a9
...
@@ -18,6 +18,8 @@ const (
...
@@ -18,6 +18,8 @@ const (
PsDesired
=
"/ps_desired"
PsDesired
=
"/ps_desired"
// PsAddr is the base dir for pserver to store their addr
// PsAddr is the base dir for pserver to store their addr
PsPath
=
"/ps/"
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
)
)
// EtcdClient is the etcd client that the pserver uses for fault
// EtcdClient is the etcd client that the pserver uses for fault
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
return
idx
,
nil
}
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
))
_
,
err
:=
e
.
etcdClient
.
Put
(
ctx
,
key
,
string
(
value
))
cancel
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/optimizer.go
浏览文件 @
15f021a9
...
@@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
...
@@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
}
}
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
)
*
optimizer
{
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
,
State
[]
byte
)
*
optimizer
{
o
:=
&
optimizer
{}
o
:=
&
optimizer
{}
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
p
:=
paramWithConfigs
.
Param
p
:=
paramWithConfigs
.
Param
c
:=
paramWithConfigs
.
Config
c
:=
paramWithConfigs
.
Config
s
:=
State
log
.
WithFields
(
log
.
Fields
{
log
.
WithFields
(
log
.
Fields
{
"ElementType"
:
p
.
ElementType
,
"ElementType"
:
p
.
ElementType
,
"ParamSize"
:
len
(
p
.
Content
),
"ParamSize"
:
len
(
p
.
Content
),
"ConfigSize"
:
len
(
c
),
"ConfigSize"
:
len
(
c
),
"StateSize"
:
len
(
s
),
})
.
Info
(
"New Optimizer Created with config:"
)
})
.
Info
(
"New Optimizer Created with config:"
)
var
cbuffer
unsafe
.
Pointer
var
cbuffer
unsafe
.
Pointer
cbuffer
=
C
.
malloc
(
C
.
size_t
(
len
(
p
.
Content
)))
cbuffer
=
C
.
malloc
(
C
.
size_t
(
len
(
p
.
Content
)))
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
C
.
size_t
(
len
(
p
.
Content
)))
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
C
.
size_t
(
len
(
p
.
Content
)))
var
cstate
unsafe
.
Pointer
if
len
(
s
)
!=
0
{
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
(
*
C
.
char
)(
nullPtr
),
0
)
return
o
return
o
}
}
...
@@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
...
@@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
return
cArrayToSlice
(
buffer
,
int
(
bufferLen
)
*
C
.
sizeof_float
)
return
cArrayToSlice
(
buffer
,
int
(
bufferLen
)
*
C
.
sizeof_float
)
}
}
func
(
o
*
optimizer
)
GetStates
()
[]
byte
{
var
cbuffer
*
C
.
char
cbuffer_len
:=
C
.
paddle_optimizer_get_state
(
o
.
opt
,
&
cbuffer
)
return
cArrayToSlice
(
unsafe
.
Pointer
(
cbuffer
),
int
(
cbuffer_len
))
}
func
(
o
*
optimizer
)
UpdateParameter
(
g
Gradient
)
error
{
func
(
o
*
optimizer
)
UpdateParameter
(
g
Gradient
)
error
{
if
o
.
elementType
!=
g
.
ElementType
{
if
o
.
elementType
!=
g
.
ElementType
{
return
fmt
.
Errorf
(
"Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v"
,
g
.
Name
,
o
.
elementType
,
g
.
ElementType
)
return
fmt
.
Errorf
(
"Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v"
,
g
.
Name
,
o
.
elementType
,
g
.
ElementType
)
...
...
go/pserver/optimizer_test.go
浏览文件 @
15f021a9
...
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
...
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param
:
p
,
Param
:
p
,
Config
:
config
,
Config
:
config
,
}
}
o
:=
newOptimizer
(
param
)
o
:=
newOptimizer
(
param
,
nil
)
o
.
Cleanup
()
o
.
Cleanup
()
}
}
go/pserver/service.go
浏览文件 @
15f021a9
package
pserver
package
pserver
import
(
import
(
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"errors"
"fmt"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"sync"
"time"
log
"github.com/sirupsen/logrus"
)
)
// ElementType is the type of elements of a Parameter.
// ElementType is the type of elements of a Parameter.
...
@@ -39,26 +51,55 @@ type ParameterWithConfig struct {
...
@@ -39,26 +51,55 @@ type ParameterWithConfig struct {
Config
[]
byte
// parameter configuration in Proto Buffer format
Config
[]
byte
// parameter configuration in Proto Buffer format
}
}
// ParameterCheckpoint is Parameter and State checkpoint
type
ParameterCheckpoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Md5sum
string
`json:"md5sum"`
Timestamp
string
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
ParameterCheckpoint
// Gradient is the gradient of the parameter.
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
type
Gradient
Parameter
// Service is the RPC service for pserver.
// Service is the RPC service for pserver.
type
Service
struct
{
type
Service
struct
{
initialized
chan
struct
{}
initialized
chan
struct
{}
idx
int
idx
int
checkpointInterval
time
.
Duration
mu
sync
.
Mutex
checkpointPath
string
optMap
map
[
string
]
*
optimizer
client
*
EtcdClient
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
}
}
// NewService creates a new service, will bypass etcd registration if no
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
// endpoints specified.
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
s
:=
&
Service
{
idx
:
idx
,
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
),
checkpointPath
:
path
,
client
:
client
,
}
}
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
initialized
=
make
(
chan
struct
{})
s
.
initialized
=
make
(
chan
struct
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
}
}
return
s
,
nil
return
s
,
nil
}
}
...
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
...
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// properly memory aligned, if not, make copy to a memory
// aligned region.
// aligned region.
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
)
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
,
nil
)
return
nil
return
nil
}
}
...
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
...
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
return
nil
}
}
//
Save tells the parameter server to save parameters.
//
pserver save checkpoint
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
doCheckpoint
(
)
error
{
<-
s
.
initialized
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
cp
:=
make
([]
ParameterCheckpoint
,
0
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
ParameterCheckpoint
pc
.
ParamConfig
.
Param
.
Name
=
name
pc
.
ParamConfig
.
Param
.
ElementType
=
opt
.
elementType
pc
.
ParamConfig
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
cp
)
if
err
!=
nil
{
return
err
}
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
h
:=
md5
.
New
()
cpMeta
.
Md5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
// TODO
cpMetajson
,
_
:=
json
.
Marshal
(
cpMeta
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
if
err
!=
nil
{
return
err
}
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
if
err
!=
nil
{
return
err
}
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
buf
.
Bytes
())
writer
.
Flush
()
if
err
!=
nil
{
return
err
}
return
nil
return
nil
}
}
go/pserver/service_test.go
浏览文件 @
15f021a9
...
@@ -15,7 +15,8 @@ const (
...
@@ -15,7 +15,8 @@ const (
)
)
func
TestServiceFull
(
t
*
testing
.
T
)
{
func
TestServiceFull
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
...
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
...
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
...
@@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
...
@@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
}
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
FailNow
()
t
.
FailNow
()
...
@@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
...
@@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
}
}
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
...
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch
<-
struct
{}{}
ch
<-
struct
{}{}
}()
}()
wg
.
Add
(
1
)
go
func
()
{
err
:=
s
.
Save
(
""
,
nil
)
if
err
!=
nil
{
errCh
<-
err
}
wg
.
Done
()
ch
<-
struct
{}{}
}()
time
.
Sleep
(
50
*
time
.
Millisecond
)
time
.
Sleep
(
50
*
time
.
Millisecond
)
select
{
select
{
...
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
wg
.
Wait
()
}
}
func
TestCheckpointSpeed
(
t
*
testing
.
T
)
{
//TODO(zhihong): test speed
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录