Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
d1cda903
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d1cda903
编写于
8月 07, 2017
作者:
H
helinwang
提交者:
GitHub
8月 07, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3222 from helinwang/checkpoint
Fix pserver save / load checkpoint
上级
9bb57153
33fb8d7a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
128 addition
and
68 deletion
+128
-68
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+3
-3
go/glide.lock
go/glide.lock
+4
-2
go/glide.yaml
go/glide.yaml
+5
-3
go/pserver/client/client_test.go
go/pserver/client/client_test.go
+1
-1
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+8
-2
go/pserver/optimizer.go
go/pserver/optimizer.go
+2
-0
go/pserver/service.go
go/pserver/service.go
+101
-53
go/pserver/service_test.go
go/pserver/service_test.go
+4
-4
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
d1cda903
...
...
@@ -32,7 +32,7 @@ import (
func
main
()
{
port
:=
flag
.
Int
(
"port"
,
0
,
"port of the pserver"
)
index
:=
flag
.
Int
(
"index"
,
-
1
,
"index of th
is pserver, should be larger or equal than 0
"
)
index
:=
flag
.
Int
(
"index"
,
-
1
,
"index of th
e pserver, set to -1 if use etcd for auto pserver index registry
"
)
etcdEndpoint
:=
flag
.
String
(
"etcd-endpoint"
,
"http://127.0.0.1:2379"
,
"comma separated endpoint string for pserver to connect to etcd"
)
dialTimeout
:=
flag
.
Duration
(
"dial-timeout"
,
5
*
time
.
Second
,
"dial timeout"
)
...
...
@@ -60,12 +60,12 @@ func main() {
idx
,
err
=
e
.
Register
(
*
port
)
candy
.
Must
(
err
)
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
cp
,
err
=
pserver
.
LoadCheckpoint
(
e
,
idx
)
if
err
!=
nil
{
if
err
==
pserver
.
ErrCheckpointNotFound
{
log
.
Infof
(
"Could not find the pserver checkpoint."
)
}
else
{
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
panic
(
err
)
}
}
}
...
...
go/glide.lock
浏览文件 @
d1cda903
hash:
2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c
updated: 2017-0
7-29T07:34:48.722757905+08:00
hash:
1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-0
8-03T21:46:51.744995189Z
imports:
- name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...
...
@@ -145,6 +145,8 @@ imports:
version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages:
- xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus
version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy
...
...
go/glide.yaml
浏览文件 @
d1cda903
...
...
@@ -14,11 +14,13 @@ import:
version
:
^1.0.0
-
package
:
github.com/topicai/candy
-
package
:
golang.org/x/crypto
vcs
:
git
repo
:
https://github.com/golang/crypto.git
-
package
:
golang.org/x/sys
vcs
:
git
-
package
:
golang.org/x/sys
repo
:
https://github.com/golang/sys.git
-
package
:
golang.org/x/text
vcs
:
git
-
package
:
golang.org/x/text
repo
:
https://github.com/golang/text.git
vcs
:
git
-
package
:
github.com/satori/go.uuid
version
:
v1.1.0
go/pserver/client/client_test.go
浏览文件 @
d1cda903
...
...
@@ -59,7 +59,7 @@ func initClient() [numPserver]int {
go
func
(
l
net
.
Listener
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
panic
(
err
)
}
...
...
go/pserver/etcd_client.go
浏览文件 @
d1cda903
...
...
@@ -206,6 +206,7 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
if
err
!=
nil
{
return
[]
byte
{},
err
}
kvs
:=
resp
.
Kvs
if
len
(
kvs
)
==
0
{
return
[]
byte
{},
nil
...
...
@@ -215,9 +216,14 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
)
error
{
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
,
withLease
bool
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
timeout
)
_
,
err
:=
e
.
client
.
Put
(
ctx
,
key
,
string
(
value
),
clientv3
.
WithLease
(
e
.
sess
.
Lease
()))
var
err
error
if
withLease
{
_
,
err
=
e
.
client
.
Put
(
ctx
,
key
,
string
(
value
),
clientv3
.
WithLease
(
e
.
sess
.
Lease
()))
}
else
{
_
,
err
=
e
.
client
.
Put
(
ctx
,
key
,
string
(
value
))
}
cancel
()
return
err
}
...
...
go/pserver/optimizer.go
浏览文件 @
d1cda903
...
...
@@ -32,6 +32,7 @@ type optimizer struct {
opt
*
C
.
struct_paddle_optimizer
elementType
ElementType
contentLen
int
config
[]
byte
}
func
cArrayToSlice
(
p
unsafe
.
Pointer
,
len
int
)
[]
byte
{
...
...
@@ -70,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
o
.
config
=
c
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
paramBufferSize
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
return
o
...
...
go/pserver/service.go
浏览文件 @
d1cda903
...
...
@@ -25,11 +25,13 @@ import (
"fmt"
"io/ioutil"
"os"
"path
/filepath
"
"path"
"strconv"
"sync"
"time"
uuid
"github.com/satori/go.uuid"
log
"github.com/sirupsen/logrus"
)
...
...
@@ -42,9 +44,9 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found")
// RPC error message.
const
(
AlreadyInitialized
=
"pserver already initialized"
Uninitialized
=
"pserver not fully initialized"
CheckpointMD5Failed
=
"checkpoint file MD5
validation failed"
AlreadyInitialized
=
"pserver already initialized"
Uninitialized
=
"pserver not fully initialized"
WrongChecksum
=
"checkpoint file checksum
validation failed"
)
// Supported element types.
...
...
@@ -73,11 +75,12 @@ type ParameterWithConfig struct {
// checkpointMeta saves checkpoint metadata
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Path
string
`json:"path"`
MD5
string
`json:"md5"`
Timestamp
int64
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
// Checkpoint is the pserver shard persist in file
.
type
Checkpoint
[]
parameterCheckpoint
// Gradient is the gradient of the parameter.
...
...
@@ -90,50 +93,58 @@ type Service struct {
checkpointInterval
time
.
Duration
checkpointPath
string
client
*
EtcdClient
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
}
// parameterCheckpoint saves parameter checkpoint
// parameterCheckpoint saves parameter checkpoint
.
type
parameterCheckpoint
struct
{
ParameterWithConfig
State
[]
byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func
NewCheckpointFromFile
(
cpPath
string
,
idx
int
,
e
*
EtcdClient
)
(
Checkpoint
,
error
)
{
v
,
err
:=
e
.
GetKey
(
PsPath
+
string
(
idx
),
3
*
time
.
Second
)
func
loadMeta
(
e
*
EtcdClient
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
v
,
err
:=
e
.
GetKey
(
PsCheckpoint
+
strconv
.
Itoa
(
idx
),
3
*
time
.
Second
)
if
err
!=
nil
{
return
nil
,
err
return
}
if
len
(
v
)
==
0
{
return
nil
,
ErrCheckpointNotFound
err
=
ErrCheckpointNotFound
return
}
var
cpMeta
checkpointMeta
if
err
=
json
.
Unmarshal
(
v
,
&
cpMeta
);
err
!=
nil
{
return
nil
,
err
if
err
=
json
.
Unmarshal
(
v
,
&
meta
);
err
!=
nil
{
return
}
fn
:=
filepath
.
Join
(
cpPath
,
cpMeta
.
UUID
)
if
_
,
err
=
os
.
Stat
(
fn
);
os
.
IsNotExist
(
err
)
{
return
}
// LoadCheckpoint loads checkpoint from file.
func
LoadCheckpoint
(
e
*
EtcdClient
,
idx
int
)
(
Checkpoint
,
error
)
{
cpMeta
,
err
:=
loadMeta
(
e
,
idx
)
if
err
!=
nil
{
return
nil
,
err
}
content
,
err
:=
ioutil
.
ReadFile
(
fn
)
content
,
err
:=
ioutil
.
ReadFile
(
cpMeta
.
Path
)
if
err
!=
nil
{
return
nil
,
err
}
// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
if
md5
!=
cpMeta
.
MD5
{
return
nil
,
errors
.
New
(
CheckpointMD5Failed
)
return
nil
,
errors
.
New
(
WrongChecksum
)
}
dec
:=
gob
.
NewDecoder
(
bytes
.
NewReader
(
content
))
cp
:=
Checkpoint
{}
if
err
=
dec
.
Decode
(
cp
);
err
!=
nil
{
var
cp
Checkpoint
if
err
=
dec
.
Decode
(
&
cp
);
err
!=
nil
{
return
nil
,
err
}
return
cp
,
nil
...
...
@@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
}
close
(
s
.
initialized
)
go
func
()
{
t
:=
time
.
Tick
(
s
.
checkpointInterval
)
for
range
t
{
err
:=
s
.
checkpoint
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
}
}()
return
nil
}
...
...
@@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
}
// pserver save checkpoint
func
(
s
*
Service
)
doCheckpoint
()
(
err
error
)
{
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
func
traceTime
(
start
time
.
Time
,
name
string
)
{
elapsed
:=
time
.
Since
(
start
)
log
.
Infof
(
"%s took %v"
,
name
,
elapsed
)
}
// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func
(
s
*
Service
)
checkpoint
()
(
err
error
)
{
log
.
Infoln
(
"Begin save checkpoint."
)
defer
traceTime
(
time
.
Now
(),
"save checkpoint"
)
s
.
mu
.
Lock
()
cp
:=
make
([]
parameterCheckpoint
,
len
(
s
.
optMap
))
index
:=
0
// TODO(helin): write checkpoint incrementally to reduce memory
// footprint during checkpoint.
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
parameterCheckpoint
pc
.
Param
.
Name
=
name
pc
.
Param
.
ElementType
=
opt
.
elementType
pc
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
Config
=
opt
.
config
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
s
.
mu
.
Unlock
()
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
=
encoder
.
Encode
(
cp
)
...
...
@@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) {
return
}
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
UnixNano
()
h
:=
md5
.
New
()
cpMeta
.
MD5
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMetajson
,
err
:=
json
.
Marshal
(
cpMeta
)
if
err
!=
nil
{
return
}
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
*
time
.
Second
)
if
err
!=
nil
{
return
}
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
if
err
!=
nil
{
log
.
Infof
(
"Removing checkpoint %s failed"
,
cpMeta
.
UUID
)
}
else
{
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
id
:=
uuid
.
NewV4
()
.
String
()
p
:=
path
.
Join
(
s
.
checkpointPath
,
id
)
f
,
err
:=
os
.
Create
(
p
)
if
err
!=
nil
{
return
}
...
...
@@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) {
return
}
oldMeta
,
err
:=
loadMeta
(
s
.
client
,
s
.
idx
)
if
err
==
ErrCheckpointNotFound
{
log
.
Infoln
(
"Do not have existing checkpoint."
)
err
=
nil
}
if
err
!=
nil
{
return
}
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMeta
:=
checkpointMeta
{
UUID
:
id
,
Timestamp
:
time
.
Now
()
.
UnixNano
(),
MD5
:
md5
,
Path
:
p
,
}
json
,
err
:=
json
.
Marshal
(
cpMeta
)
if
err
!=
nil
{
return
}
err
=
s
.
client
.
PutKey
(
PsCheckpoint
+
strconv
.
Itoa
(
s
.
idx
),
json
,
3
*
time
.
Second
,
false
)
if
err
!=
nil
{
return
}
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Errorln
(
rmErr
)
}
}
return
}
go/pserver/service_test.go
浏览文件 @
d1cda903
...
...
@@ -30,7 +30,7 @@ const (
func
TestServiceFull
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -102,7 +102,7 @@ func TestServiceFull(t *testing.T) {
func
TestMultipleInit
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
...
...
@@ -119,7 +119,7 @@ func TestMultipleInit(t *testing.T) {
func
TestUninitialized
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
Fatal
(
err
)
...
...
@@ -128,7 +128,7 @@ func TestUninitialized(t *testing.T) {
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
s
,
err
:=
pserver
.
NewService
(
0
,
time
.
Hour
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录