Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
19bfb8a1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
19bfb8a1
编写于
7月 13, 2017
作者:
Y
Yancey
提交者:
GitHub
7月 13, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PServer recovery from checkpoint (#2741)
* Server recovery from checkpoint
上级
f5f7d6bd
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
121 addition
and
62 deletion
+121
-62
.gitignore
.gitignore
+3
-0
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+17
-22
go/glide.lock
go/glide.lock
+8
-6
go/glide.yaml
go/glide.yaml
+1
-0
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+19
-3
go/pserver/service.go
go/pserver/service.go
+73
-31
未找到文件。
.gitignore
浏览文件 @
19bfb8a1
...
...
@@ -22,3 +22,6 @@ cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake
go/cmd/pserver/pserver.go
浏览文件 @
19bfb8a1
...
...
@@ -8,6 +8,7 @@ import (
"time"
"github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver"
log
"github.com/sirupsen/logrus"
...
...
@@ -18,53 +19,47 @@ func main() {
index
:=
flag
.
Int
(
"index"
,
-
1
,
"index of this pserver, should be larger or equal than 0"
)
etcdEndpoint
:=
flag
.
String
(
"etcd-endpoint"
,
"http://127.0.0.1:2379"
,
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
etcdTimeout
:=
flag
.
Duration
(
"etcd-timeout"
,
5
*
time
.
Second
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
600
,
"save checkpoint per interval seconds"
)
checkpointInterval
:=
flag
.
Duration
(
"checkpoint-interval"
,
600
*
time
.
Second
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
level
,
err
:=
log
.
ParseLevel
(
*
logLevel
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
log
.
SetLevel
(
level
)
var
idx
int
var
cp
pserver
.
Checkpoint
var
cp
*
pserver
.
Checkpoint
var
e
*
pserver
.
EtcdClient
if
*
index
>=
0
{
idx
=
*
index
}
else
{
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
*
etcdTimeout
)
idx
,
err
=
e
.
Register
()
candy
.
Must
(
err
)
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
if
err
!=
nil
{
panic
(
err
)
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
}
}
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
err
=
rpc
.
Register
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
rpc
.
HandleHTTP
()
l
,
err
:=
net
.
Listen
(
"tcp"
,
":"
+
strconv
.
Itoa
(
*
port
))
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
log
.
Infof
(
"start pserver at port %d"
,
*
port
)
err
=
http
.
Serve
(
l
,
nil
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
}
go/glide.lock
浏览文件 @
19bfb8a1
hash:
b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7
updated: 2017-0
6-27T14:05:48.925262819
+08:00
hash:
a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
updated: 2017-0
7-11T10:04:40.786745417
+08:00
imports:
- name: github.com/coreos/etcd
version:
61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7
version:
cb2a496c4ddd1c87a9f280e116649b599999ec79
subpackages:
- auth/authpb
- clientv3
...
...
@@ -22,7 +22,9 @@ imports:
- name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129
- name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b
version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages:
...
...
@@ -34,11 +36,11 @@ imports:
- lex/httplex
- trace
- name: golang.org/x/sys
version:
f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733
version:
abf9c25f54453410d0c6668e519582a9e1115027
subpackages:
- unix
- name: golang.org/x/text
version:
4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77
version:
cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
subpackages:
- secure/bidirule
- transform
...
...
go/glide.yaml
浏览文件 @
19bfb8a1
...
...
@@ -10,3 +10,4 @@ import:
version
:
^1.7.4-pre
-
package
:
github.com/sirupsen/logrus
version
:
^1.0.0
-
package
:
github.com/topicai/candy
go/pserver/etcd_client.go
浏览文件 @
19bfb8a1
...
...
@@ -16,7 +16,7 @@ import (
const
(
// PsDesired is etcd path for store desired pserver count
PsDesired
=
"/ps_desired"
// Ps
Addr
is the base dir for pserver to store their addr
// Ps
Path
is the base dir for pserver to store their addr
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
...
...
@@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
}
// GetKey gets the value by the specified key
func
(
e
*
EtcdClient
)
GetKey
(
key
string
,
timeout
time
.
Duration
)
([]
byte
,
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
timeout
)
resp
,
err
:=
e
.
etcdClient
.
Get
(
ctx
,
key
)
cancel
()
if
err
!=
nil
{
return
[]
byte
{},
err
}
kvs
:=
resp
.
Kvs
if
len
(
kvs
)
==
0
{
return
[]
byte
{},
nil
}
v
:=
kvs
[
0
]
.
Value
return
v
,
nil
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
)
)
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
out
)
_
,
err
:=
e
.
etcdClient
.
Put
(
ctx
,
key
,
string
(
value
))
cancel
()
if
err
!=
nil
{
...
...
go/pserver/service.go
浏览文件 @
19bfb8a1
...
...
@@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strconv"
...
...
@@ -21,14 +22,14 @@ import (
// ElementType is the type of elements of a Parameter.
type
ElementType
int
// RPC error message.
const
(
// AlreadyInitialized is true if pserver is initialized
AlreadyInitialized
=
"pserver already initialized"
// Uninitialized is true if pserver not fully initialized
Uninitialized
=
"pserver not fully initialized"
CheckpointMD5Failed
=
"checkpoint file MD5 validation failed"
)
// Supported element types
// Supported element types
.
const
(
Int32
ElementType
=
iota
UInt32
...
...
@@ -51,21 +52,15 @@ type ParameterWithConfig struct {
Config
[]
byte
// parameter configuration in Proto Buffer format
}
// ParameterCheckpoint is Parameter and State checkpoint
type
ParameterCheckpoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
// checkpointMeta saves checkpoint metadata
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
M
d5sum
string
`json:"md5sum
"`
Timestamp
string
`json:"timestamp"`
M
D5
string
`json:"md5
"`
Timestamp
int64
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
P
arameterCheckpoint
type
Checkpoint
[]
p
arameterCheckpoint
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
...
...
@@ -81,12 +76,53 @@ type Service struct {
optMap
map
[
string
]
*
optimizer
}
// parameterCheckpoint saves parameter checkpoint
type
parameterCheckpoint
struct
{
ParameterWithConfig
State
[]
byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func
NewCheckpointFromFile
(
cpPath
string
,
idx
int
,
e
*
EtcdClient
)
(
*
Checkpoint
,
error
)
{
v
,
err
:=
e
.
GetKey
(
PsPath
+
string
(
idx
),
3
*
time
.
Second
)
if
err
!=
nil
{
return
nil
,
err
}
var
cpMeta
checkpointMeta
if
err
=
json
.
Unmarshal
(
v
,
&
cpMeta
);
err
!=
nil
{
return
nil
,
err
}
fn
:=
filepath
.
Join
(
cpPath
,
cpMeta
.
UUID
)
if
_
,
err
=
os
.
Stat
(
fn
);
os
.
IsNotExist
(
err
)
{
return
nil
,
err
}
content
,
err
:=
ioutil
.
ReadFile
(
fn
)
if
err
!=
nil
{
return
nil
,
err
}
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
if
md5
!=
cpMeta
.
MD5
{
return
nil
,
errors
.
New
(
CheckpointMD5Failed
)
}
dec
:=
gob
.
NewDecoder
(
bytes
.
NewReader
(
content
))
cp
:=
&
Checkpoint
{}
if
err
=
dec
.
Decode
(
cp
);
err
!=
nil
{
return
nil
,
err
}
return
cp
,
nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
// endpoints specified.
It will recovery from checkpoint file if a exists a specified checkpoint.
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
*
EtcdClient
,
cp
*
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
)
,
checkpointInterval
:
interval
,
checkpointPath
:
path
,
client
:
client
,
}
...
...
@@ -94,10 +130,12 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp
s
.
initialized
=
make
(
chan
struct
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
for
_
,
item
:=
range
*
cp
{
p
:=
ParameterWithConfig
{
Param
:
item
.
Param
,
Config
:
item
.
Config
,
}
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
item
.
State
)
}
}
return
s
,
nil
...
...
@@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error {
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
cp
:=
make
([]
ParameterCheckpoint
,
0
,
len
(
s
.
optMap
))
cp
:=
make
([]
parameterCheckpoint
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
P
arameterCheckpoint
pc
.
Param
Config
.
Param
.
Name
=
name
pc
.
Param
Config
.
Param
.
ElementType
=
opt
.
elementType
pc
.
Param
Config
.
Param
.
Content
=
opt
.
GetWeights
()
var
pc
p
arameterCheckpoint
pc
.
Param
.
Name
=
name
pc
.
Param
.
ElementType
=
opt
.
elementType
pc
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
...
...
@@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error {
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
cpMeta
.
Timestamp
=
time
.
Now
()
.
UnixNano
()
h
:=
md5
.
New
()
cpMeta
.
M
d5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMeta
.
M
D5
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMetajson
,
_
:=
json
.
Marshal
(
cpMeta
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
*
time
.
Second
)
if
err
!=
nil
{
return
err
}
...
...
@@ -219,8 +257,12 @@ func (s *Service) doCheckpoint() error {
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
if
err
!=
nil
{
log
.
Infof
(
"Removing checkpoint %s failed"
,
cpMeta
.
UUID
)
}
else
{
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
if
err
!=
nil
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录