Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
d1cda903
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d1cda903
编写于
8月 07, 2017
作者:
H
helinwang
提交者:
GitHub
8月 07, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3222 from helinwang/checkpoint
Fix pserver save / load checkpoint
上级
9bb57153
33fb8d7a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
128 addition
and
68 deletion
+128
-68
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+3
-3
go/glide.lock
go/glide.lock
+4
-2
go/glide.yaml
go/glide.yaml
+5
-3
go/pserver/client/client_test.go
go/pserver/client/client_test.go
+1
-1
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+8
-2
go/pserver/optimizer.go
go/pserver/optimizer.go
+2
-0
go/pserver/service.go
go/pserver/service.go
+101
-53
go/pserver/service_test.go
go/pserver/service_test.go
+4
-4
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
d1cda903
...
...
@@ -32,7 +32,7 @@ import (
func
main
()
{
port
:=
flag
.
Int
(
"port"
,
0
,
"port of the pserver"
)
index
:=
flag
.
Int
(
"index"
,
-
1
,
"index of th
is pserver, should be larger or equal than 0
"
)
index
:=
flag
.
Int
(
"index"
,
-
1
,
"index of th
e pserver, set to -1 if use etcd for auto pserver index registry
"
)
etcdEndpoint
:=
flag
.
String
(
"etcd-endpoint"
,
"http://127.0.0.1:2379"
,
"comma separated endpoint string for pserver to connect to etcd"
)
dialTimeout
:=
flag
.
Duration
(
"dial-timeout"
,
5
*
time
.
Second
,
"dial timeout"
)
...
...
@@ -60,12 +60,12 @@ func main() {
idx
,
err
=
e
.
Register
(
*
port
)
candy
.
Must
(
err
)
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
cp
,
err
=
pserver
.
LoadCheckpoint
(
e
,
idx
)
if
err
!=
nil
{
if
err
==
pserver
.
ErrCheckpointNotFound
{
log
.
Infof
(
"Could not find the pserver checkpoint."
)
}
else
{
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
panic
(
err
)
}
}
}
...
...
go/glide.lock
浏览文件 @
d1cda903
hash:
2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c
updated: 2017-0
7-29T07:34:48.722757905+08:00
hash:
1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-0
8-03T21:46:51.744995189Z
imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...
...
@@ -145,6 +145,8 @@ imports:
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus
version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy
...
...
go/glide.yaml
浏览文件 @
d1cda903
...
...
@@ -14,11 +14,13 @@ import:
version
:
^1.0.0
-
package
:
github.com/topicai/candy
-
package
:
golang.org/x/crypto
vcs
:
git
repo
:
https://github.com/golang/crypto.git
-
package
:
golang.org/x/sys
vcs
:
git
-
package
:
golang.org/x/sys
repo
:
https://github.com/golang/sys.git
-
package
:
golang.org/x/text
vcs
:
git
-
package
:
golang.org/x/text
repo
:
https://github.com/golang/text.git
vcs
:
git
-
package
:
github.com/satori/go.uuid
version
:
v1.1.0
go/pserver/client/client_test.go
浏览文件 @
d1cda903
...
...
@@ -59,7 +59,7 @@ func initClient() [numPserver]int {
go
func
(
l
net
.
Listener
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
panic
(
err
)
}
...
...
go/pserver/etcd_client.go
浏览文件 @
d1cda903
...
...
@@ -206,6 +206,7 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
if
err
!=
nil
{
return
[]
byte
{},
err
}
kvs
:=
resp
.
Kvs
if
len
(
kvs
)
==
0
{
return
[]
byte
{},
nil
...
...
@@ -215,9 +216,14 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
)
error
{
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
,
withLease
bool
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
timeout
)
_
,
err
:=
e
.
client
.
Put
(
ctx
,
key
,
string
(
value
),
clientv3
.
WithLease
(
e
.
sess
.
Lease
()))
var
err
error
if
withLease
{
_
,
err
=
e
.
client
.
Put
(
ctx
,
key
,
string
(
value
),
clientv3
.
WithLease
(
e
.
sess
.
Lease
()))
}
else
{
_
,
err
=
e
.
client
.
Put
(
ctx
,
key
,
string
(
value
))
}
cancel
()
return
err
}
...
...
go/pserver/optimizer.go
浏览文件 @
d1cda903
...
...
@@ -32,6 +32,7 @@ type optimizer struct {
opt
*
C
.
struct_paddle_optimizer
elementType
ElementType
contentLen
int
config
[]
byte
}
func
cArrayToSlice
(
p
unsafe
.
Pointer
,
len
int
)
[]
byte
{
...
...
@@ -70,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
o
.
config
=
c
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
paramBufferSize
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
return
o
...
...
go/pserver/service.go
浏览文件 @
d1cda903
...
...
@@ -25,11 +25,13 @@ import (
"fmt"
"io/ioutil"
"os"
"path
/filepath
"
"path"
"strconv"
"sync"
"time"
uuid
"github.com/satori/go.uuid"
log
"github.com/sirupsen/logrus"
)
...
...
@@ -42,9 +44,9 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found")
// RPC error message.
const
(
AlreadyInitialized
=
"pserver already initialized"
Uninitialized
=
"pserver not fully initialized"
CheckpointMD5Failed
=
"checkpoint file MD5
validation failed"
AlreadyInitialized
=
"pserver already initialized"
Uninitialized
=
"pserver not fully initialized"
WrongChecksum
=
"checkpoint file checksum
validation failed"
)
// Supported element types.
...
...
@@ -73,11 +75,12 @@ type ParameterWithConfig struct {
// checkpointMeta saves checkpoint metadata
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Path
string
`json:"path"`
MD5
string
`json:"md5"`
Timestamp
int64
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
// Checkpoint is the pserver shard persist in file
.
type
Checkpoint
[]
parameterCheckpoint
// Gradient is the gradient of the parameter.
...
...
@@ -90,50 +93,58 @@ type Service struct {
checkpointInterval
time
.
Duration
checkpointPath
string
client
*
EtcdClient
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
}
// parameterCheckpoint saves parameter checkpoint
// parameterCheckpoint saves parameter checkpoint
.
type
parameterCheckpoint
struct
{
ParameterWithConfig
State
[]
byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func
NewCheckpointFromFile
(
cpPath
string
,
idx
int
,
e
*
EtcdClient
)
(
Checkpoint
,
error
)
{
v
,
err
:=
e
.
GetKey
(
PsPath
+
string
(
idx
),
3
*
time
.
Second
)
func
loadMeta
(
e
*
EtcdClient
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
v
,
err
:=
e
.
GetKey
(
PsCheckpoint
+
strconv
.
Itoa
(
idx
),
3
*
time
.
Second
)
if
err
!=
nil
{
return
nil
,
err
return
}
if
len
(
v
)
==
0
{
return
nil
,
ErrCheckpointNotFound
err
=
ErrCheckpointNotFound
return
}
var
cpMeta
checkpointMeta
if
err
=
json
.
Unmarshal
(
v
,
&
cpMeta
);
err
!=
nil
{
return
nil
,
err
if
err
=
json
.
Unmarshal
(
v
,
&
meta
);
err
!=
nil
{
return
}
fn
:=
filepath
.
Join
(
cpPath
,
cpMeta
.
UUID
)
if
_
,
err
=
os
.
Stat
(
fn
);
os
.
IsNotExist
(
err
)
{
return
}
// LoadCheckpoint loads checkpoint from file.
func
LoadCheckpoint
(
e
*
EtcdClient
,
idx
int
)
(
Checkpoint
,
error
)
{
cpMeta
,
err
:=
loadMeta
(
e
,
idx
)
if
err
!=
nil
{
return
nil
,
err
}
content
,
err
:=
ioutil
.
ReadFile
(
fn
)
content
,
err
:=
ioutil
.
ReadFile
(
cpMeta
.
Path
)
if
err
!=
nil
{
return
nil
,
err
}
// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
if
md5
!=
cpMeta
.
MD5
{
return
nil
,
errors
.
New
(
CheckpointMD5Failed
)
return
nil
,
errors
.
New
(
WrongChecksum
)
}
dec
:=
gob
.
NewDecoder
(
bytes
.
NewReader
(
content
))
cp
:=
Checkpoint
{}
if
err
=
dec
.
Decode
(
cp
);
err
!=
nil
{
var
cp
Checkpoint
if
err
=
dec
.
Decode
(
&
cp
);
err
!=
nil
{
return
nil
,
err
}
return
cp
,
nil
...
...
@@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
}
close
(
s
.
initialized
)
go
func
()
{
t
:=
time
.
Tick
(
s
.
checkpointInterval
)
for
range
t
{
err
:=
s
.
checkpoint
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
}
}()
return
nil
}
...
...
@@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
}
// pserver save checkpoint
func
(
s
*
Service
)
doCheckpoint
()
(
err
error
)
{
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
func
traceTime
(
start
time
.
Time
,
name
string
)
{
elapsed
:=
time
.
Since
(
start
)
log
.
Infof
(
"%s took %v"
,
name
,
elapsed
)
}
// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func
(
s
*
Service
)
checkpoint
()
(
err
error
)
{
log
.
Infoln
(
"Begin save checkpoint."
)
defer
traceTime
(
time
.
Now
(),
"save checkpoint"
)
s
.
mu
.
Lock
()
cp
:=
make
([]
parameterCheckpoint
,
len
(
s
.
optMap
))
index
:=
0
// TODO(helin): write checkpoint incrementally to reduce memory
// footprint during checkpoint.
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
parameterCheckpoint
pc
.
Param
.
Name
=
name
pc
.
Param
.
ElementType
=
opt
.
elementType
pc
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
Config
=
opt
.
config
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
s
.
mu
.
Unlock
()
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
=
encoder
.
Encode
(
cp
)
...
...
@@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) {
return
}
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
UnixNano
()
h
:=
md5
.
New
()
cpMeta
.
MD5
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMetajson
,
err
:=
json
.
Marshal
(
cpMeta
)
if
err
!=
nil
{
return
}
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
*
time
.
Second
)
if
err
!=
nil
{
return
}
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
if
err
!=
nil
{
log
.
Infof
(
"Removing checkpoint %s failed"
,
cpMeta
.
UUID
)
}
else
{
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
id
:=
uuid
.
NewV4
()
.
String
()
p
:=
path
.
Join
(
s
.
checkpointPath
,
id
)
f
,
err
:=
os
.
Create
(
p
)
if
err
!=
nil
{
return
}
...
...
@@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) {
return
}
oldMeta
,
err
:=
loadMeta
(
s
.
client
,
s
.
idx
)
if
err
==
ErrCheckpointNotFound
{
log
.
Infoln
(
"Do not have existing checkpoint."
)
err
=
nil
}
if
err
!=
nil
{
return
}
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMeta
:=
checkpointMeta
{
UUID
:
id
,
Timestamp
:
time
.
Now
()
.
UnixNano
(),
MD5
:
md5
,
Path
:
p
,
}
json
,
err
:=
json
.
Marshal
(
cpMeta
)
if
err
!=
nil
{
return
}
err
=
s
.
client
.
PutKey
(
PsCheckpoint
+
strconv
.
Itoa
(
s
.
idx
),
json
,
3
*
time
.
Second
,
false
)
if
err
!=
nil
{
return
}
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Errorln
(
rmErr
)
}
}
return
}
go/pserver/service_test.go
浏览文件 @
d1cda903
...
...
@@ -30,7 +30,7 @@ const (
func
TestServiceFull
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -102,7 +102,7 @@ func TestServiceFull(t *testing.T) {
func
TestMultipleInit
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
...
...
@@ -119,7 +119,7 @@ func TestMultipleInit(t *testing.T) {
func
TestUninitialized
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
Fatal
(
err
)
...
...
@@ -128,7 +128,7 @@ func TestUninitialized(t *testing.T) {
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录