Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
19bfb8a1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
19bfb8a1
编写于
7月 13, 2017
作者:
Y
Yancey
提交者:
GitHub
7月 13, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PServer recovery from checkpoint (#2741)
* Server recovery from checkpoint
上级
f5f7d6bd
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
121 addition
and
62 deletion
+121
-62
.gitignore
.gitignore
+3
-0
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+17
-22
go/glide.lock
go/glide.lock
+8
-6
go/glide.yaml
go/glide.yaml
+1
-0
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+19
-3
go/pserver/service.go
go/pserver/service.go
+73
-31
未找到文件。
.gitignore
浏览文件 @
19bfb8a1
...
...
@@ -22,3 +22,6 @@ cmake-build-*
# generated while compiling
python/paddle/v2/framework/core.so
CMakeFiles
cmake_install.cmake
go/cmd/pserver/pserver.go
浏览文件 @
19bfb8a1
...
...
@@ -8,6 +8,7 @@ import (
"time"
"github.com/namsral/flag"
"github.com/topicai/candy"
"github.com/PaddlePaddle/Paddle/go/pserver"
log
"github.com/sirupsen/logrus"
...
...
@@ -18,53 +19,47 @@ func main() {
index
:=
flag
.
Int
(
"index"
,
-
1
,
"index of this pserver, should be larger or equal than 0"
)
etcdEndpoint
:=
flag
.
String
(
"etcd-endpoint"
,
"http://127.0.0.1:2379"
,
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
etcdTimeout
:=
flag
.
Duration
(
"etcd-timeout"
,
5
*
time
.
Second
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
600
,
"save checkpoint per interval seconds"
)
checkpointInterval
:=
flag
.
Duration
(
"checkpoint-interval"
,
600
*
time
.
Second
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
level
,
err
:=
log
.
ParseLevel
(
*
logLevel
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
log
.
SetLevel
(
level
)
var
idx
int
var
cp
pserver
.
Checkpoint
var
cp
*
pserver
.
Checkpoint
var
e
*
pserver
.
EtcdClient
if
*
index
>=
0
{
idx
=
*
index
}
else
{
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
*
etcdTimeout
)
idx
,
err
=
e
.
Register
()
candy
.
Must
(
err
)
cp
,
err
=
pserver
.
NewCheckpointFromFile
(
*
checkpointPath
,
idx
,
e
)
if
err
!=
nil
{
panic
(
err
)
log
.
Errorf
(
"Fetch checkpoint failed, %s"
,
err
)
}
}
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
err
=
rpc
.
Register
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
rpc
.
HandleHTTP
()
l
,
err
:=
net
.
Listen
(
"tcp"
,
":"
+
strconv
.
Itoa
(
*
port
))
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
log
.
Infof
(
"start pserver at port %d"
,
*
port
)
err
=
http
.
Serve
(
l
,
nil
)
if
err
!=
nil
{
panic
(
err
)
}
candy
.
Must
(
err
)
}
go/glide.lock
浏览文件 @
19bfb8a1
hash:
b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7
updated: 2017-0
6-27T14:05:48.925262819
+08:00
hash:
a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855
updated: 2017-0
7-11T10:04:40.786745417
+08:00
imports:
- name: github.com/coreos/etcd
version:
61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7
version:
cb2a496c4ddd1c87a9f280e116649b599999ec79
subpackages:
- auth/authpb
- clientv3
...
...
@@ -22,7 +22,9 @@ imports:
- name: github.com/PaddlePaddle/recordio
version: edfb82af0739c84f241c87390ec5649c7b28c129
- name: github.com/sirupsen/logrus
version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b
version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1
- name: github.com/topicai/candy
version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc
- name: golang.org/x/net
version: c8c74377599bd978aee1cf3b9b63a8634051cec2
subpackages:
...
...
@@ -34,11 +36,11 @@ imports:
- lex/httplex
- trace
- name: golang.org/x/sys
version:
f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733
version:
abf9c25f54453410d0c6668e519582a9e1115027
subpackages:
- unix
- name: golang.org/x/text
version:
4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77
version:
cfdf022e86b4ecfb646e1efbd7db175dd623a8fa
subpackages:
- secure/bidirule
- transform
...
...
go/glide.yaml
浏览文件 @
19bfb8a1
...
...
@@ -10,3 +10,4 @@ import:
version
:
^1.7.4-pre
-
package
:
github.com/sirupsen/logrus
version
:
^1.0.0
-
package
:
github.com/topicai/candy
go/pserver/etcd_client.go
浏览文件 @
19bfb8a1
...
...
@@ -16,7 +16,7 @@ import (
const
(
// PsDesired is etcd path for store desired pserver count
PsDesired
=
"/ps_desired"
// Ps
Addr
is the base dir for pserver to store their addr
// Ps
Path
is the base dir for pserver to store their addr
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
...
...
@@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
}
// GetKey gets the value by the specified key
func
(
e
*
EtcdClient
)
GetKey
(
key
string
,
timeout
time
.
Duration
)
([]
byte
,
error
)
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
timeout
)
resp
,
err
:=
e
.
etcdClient
.
Get
(
ctx
,
key
)
cancel
()
if
err
!=
nil
{
return
[]
byte
{},
err
}
kvs
:=
resp
.
Kvs
if
len
(
kvs
)
==
0
{
return
[]
byte
{},
nil
}
v
:=
kvs
[
0
]
.
Value
return
v
,
nil
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
)
)
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
out
)
_
,
err
:=
e
.
etcdClient
.
Put
(
ctx
,
key
,
string
(
value
))
cancel
()
if
err
!=
nil
{
...
...
go/pserver/service.go
浏览文件 @
19bfb8a1
...
...
@@ -9,6 +9,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"strconv"
...
...
@@ -21,14 +22,14 @@ import (
// ElementType is the type of elements of a Parameter.
type
ElementType
int
// RPC error message.
const
(
// AlreadyInitialized is true if pserver is initialized
AlreadyInitialized
=
"pserver already initialized"
// Uninitialized is true if pserver not fully initialized
Uninitialized
=
"pserver not fully initialized"
CheckpointMD5Failed
=
"checkpoint file MD5 validation failed"
)
// Supported element types
// Supported element types
.
const
(
Int32
ElementType
=
iota
UInt32
...
...
@@ -51,21 +52,15 @@ type ParameterWithConfig struct {
Config
[]
byte
// parameter configuration in Proto Buffer format
}
// ParameterCheckpoint is Parameter and State checkpoint
type
ParameterCheckpoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
// checkpointMeta saves checkpoint metadata
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
M
d5sum
string
`json:"md5sum
"`
Timestamp
string
`json:"timestamp"`
M
D5
string
`json:"md5
"`
Timestamp
int64
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
P
arameterCheckpoint
type
Checkpoint
[]
p
arameterCheckpoint
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
...
...
@@ -81,12 +76,53 @@ type Service struct {
optMap
map
[
string
]
*
optimizer
}
// parameterCheckpoint saves parameter checkpoint
type
parameterCheckpoint
struct
{
ParameterWithConfig
State
[]
byte
}
// NewCheckpointFromFile loads parameters and state from checkpoint file
func
NewCheckpointFromFile
(
cpPath
string
,
idx
int
,
e
*
EtcdClient
)
(
*
Checkpoint
,
error
)
{
v
,
err
:=
e
.
GetKey
(
PsPath
+
string
(
idx
),
3
*
time
.
Second
)
if
err
!=
nil
{
return
nil
,
err
}
var
cpMeta
checkpointMeta
if
err
=
json
.
Unmarshal
(
v
,
&
cpMeta
);
err
!=
nil
{
return
nil
,
err
}
fn
:=
filepath
.
Join
(
cpPath
,
cpMeta
.
UUID
)
if
_
,
err
=
os
.
Stat
(
fn
);
os
.
IsNotExist
(
err
)
{
return
nil
,
err
}
content
,
err
:=
ioutil
.
ReadFile
(
fn
)
if
err
!=
nil
{
return
nil
,
err
}
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
if
md5
!=
cpMeta
.
MD5
{
return
nil
,
errors
.
New
(
CheckpointMD5Failed
)
}
dec
:=
gob
.
NewDecoder
(
bytes
.
NewReader
(
content
))
cp
:=
&
Checkpoint
{}
if
err
=
dec
.
Decode
(
cp
);
err
!=
nil
{
return
nil
,
err
}
return
cp
,
nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
// endpoints specified.
It will recovery from checkpoint file if a exists a specified checkpoint.
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
*
EtcdClient
,
cp
*
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
)
,
checkpointInterval
:
interval
,
checkpointPath
:
path
,
client
:
client
,
}
...
...
@@ -94,10 +130,12 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp
s
.
initialized
=
make
(
chan
struct
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
for
_
,
item
:=
range
*
cp
{
p
:=
ParameterWithConfig
{
Param
:
item
.
Param
,
Config
:
item
.
Config
,
}
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
item
.
State
)
}
}
return
s
,
nil
...
...
@@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error {
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
cp
:=
make
([]
ParameterCheckpoint
,
0
,
len
(
s
.
optMap
))
cp
:=
make
([]
parameterCheckpoint
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
P
arameterCheckpoint
pc
.
Param
Config
.
Param
.
Name
=
name
pc
.
Param
Config
.
Param
.
ElementType
=
opt
.
elementType
pc
.
Param
Config
.
Param
.
Content
=
opt
.
GetWeights
()
var
pc
p
arameterCheckpoint
pc
.
Param
.
Name
=
name
pc
.
Param
.
ElementType
=
opt
.
elementType
pc
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
...
...
@@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error {
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
cpMeta
.
Timestamp
=
time
.
Now
()
.
UnixNano
()
h
:=
md5
.
New
()
cpMeta
.
M
d5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMeta
.
M
D5
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMetajson
,
_
:=
json
.
Marshal
(
cpMeta
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
*
time
.
Second
)
if
err
!=
nil
{
return
err
}
...
...
@@ -219,8 +257,12 @@ func (s *Service) doCheckpoint() error {
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
if
err
!=
nil
{
log
.
Infof
(
"Removing checkpoint %s failed"
,
cpMeta
.
UUID
)
}
else
{
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
if
err
!=
nil
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录