Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
40295b9e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
40295b9e
编写于
7月 07, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"fix pserver saving etcd"
上级
bfc3b436
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
116 addition
and
81 deletion
+116
-81
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+4
-1
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+13
-0
go/pserver/optimizer.go
go/pserver/optimizer.go
+2
-2
go/pserver/service.go
go/pserver/service.go
+95
-75
go/pserver/service_test.go
go/pserver/service_test.go
+2
-3
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
40295b9e
...
@@ -20,6 +20,8 @@ func main() {
...
@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd"
)
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
"10"
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
flag
.
Parse
()
...
@@ -31,6 +33,7 @@ func main() {
...
@@ -31,6 +33,7 @@ func main() {
log
.
SetLevel
(
level
)
log
.
SetLevel
(
level
)
var
idx
int
var
idx
int
var
cp
pserver
.
Checkpoint
if
*
index
>=
0
{
if
*
index
>=
0
{
idx
=
*
index
idx
=
*
index
}
else
{
}
else
{
...
@@ -42,7 +45,7 @@ func main() {
...
@@ -42,7 +45,7 @@ func main() {
}
}
}
}
s
,
err
:=
pserver
.
NewService
(
idx
)
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
...
...
go/pserver/etcd_client.go
浏览文件 @
40295b9e
...
@@ -18,6 +18,8 @@ const (
...
@@ -18,6 +18,8 @@ const (
PsDesired
=
"/ps_desired"
PsDesired
=
"/ps_desired"
// PsAddr is the base dir for pserver to store their addr
// PsAddr is the base dir for pserver to store their addr
PsPath
=
"/ps/"
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
)
)
// EtcdClient is the etcd client that the pserver uses for fault
// EtcdClient is the etcd client that the pserver uses for fault
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
return
idx
,
nil
}
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
err
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
))
_
,
err
=
e
.
Put
(
ctx
,
key
,
value
)
cancel
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/optimizer.go
浏览文件 @
40295b9e
...
@@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
...
@@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
}
}
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
)
*
optimizer
{
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
,
State
[]
byte
)
*
optimizer
{
o
:=
&
optimizer
{}
o
:=
&
optimizer
{}
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
p
:=
paramWithConfigs
.
Param
p
:=
paramWithConfigs
.
Param
c
:=
paramWithConfigs
.
Config
c
:=
paramWithConfigs
.
Config
s
:=
paramWithConfigs
.
State
s
:=
State
log
.
WithFields
(
log
.
Fields
{
log
.
WithFields
(
log
.
Fields
{
"ElementType"
:
p
.
ElementType
,
"ElementType"
:
p
.
ElementType
,
"ParamSize"
:
len
(
p
.
Content
),
"ParamSize"
:
len
(
p
.
Content
),
...
...
go/pserver/service.go
浏览文件 @
40295b9e
...
@@ -5,10 +5,11 @@ import (
...
@@ -5,10 +5,11 @@ import (
"bytes"
"bytes"
"crypto/md5"
"crypto/md5"
"encoding/gob"
"encoding/gob"
"encoding/
hex
"
"encoding/
json
"
"errors"
"errors"
"fmt"
"fmt"
"os"
"os"
"path/filepath"
"strconv"
"strconv"
"sync"
"sync"
"time"
"time"
...
@@ -26,10 +27,6 @@ const (
...
@@ -26,10 +27,6 @@ const (
Uninitialized
=
"pserver not fully initialized"
Uninitialized
=
"pserver not fully initialized"
)
)
const
(
checkpoint_path
=
"./checkpoints/"
)
// Supported element types
// Supported element types
const
(
const
(
Int32
ElementType
=
iota
Int32
ElementType
=
iota
...
@@ -51,49 +48,68 @@ type Parameter struct {
...
@@ -51,49 +48,68 @@ type Parameter struct {
type
ParameterWithConfig
struct
{
type
ParameterWithConfig
struct
{
Param
Parameter
Param
Parameter
Config
[]
byte
// parameter configuration in Proto Buffer format
Config
[]
byte
// parameter configuration in Proto Buffer format
State
[]
byte
// parameter training state
}
}
// Checkpoint of Parameter and State
type
parameterCheckPoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Md5sum
string
`json:"md5sum"`
Timestamp
string
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
parameterCheckPoint
// Gradient is the gradient of the parameter.
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
// Service is the RPC service for pserver.
// Service is the RPC service for pserver.
type
Service
struct
{
type
Service
struct
{
initialized
chan
struct
{}
initialized
chan
struct
{}
idx
int
idx
int
checkpointInterval
int
checkpointPath
string
client
*
EtcdClient
mu
sync
.
Mutex
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
optMap
map
[
string
]
*
optimizer
}
}
type
checkpoint
struct
{
// //serialize ParameterWithConfig to byte stream
Uuid
string
// func GetBytes(content ...interface{}) ([]byte, error) {
Md5sum
string
Timestamp
string
}
//serialize ParameterWithConfig to byte stream
// var buf bytes.Buffer
func
GetBytes
(
content
...
interface
{})
([]
byte
,
error
)
{
// encoder := gob.NewEncoder(&buf)
// err := encoder.Encode(content)
var
buf
bytes
.
Buffer
// if err != nil {
encoder
:=
gob
.
NewEncoder
(
&
buf
)
// return nil, err
err
:=
encoder
.
Encode
(
content
)
// }
if
err
!=
nil
{
// return buf.Bytes(), nil
return
nil
,
err
// }
}
return
buf
.
Bytes
(),
nil
}
// NewService creates a new service, will bypass etcd registration if no
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
// endpoints specified.
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
s
:=
&
Service
{
idx
:
idx
,
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
),
checkpointPath
:
path
,
client
:
client
,
}
}
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
initialized
=
make
(
chan
struct
{})
s
.
initialized
=
make
(
chan
struct
{})
gob
.
Register
(
ParameterWithConfig
{})
gob
.
Register
(
checkpoint
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
}
}
return
s
,
nil
return
s
,
nil
}
}
...
@@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
...
@@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
return
nil
}
}
// Save tells the parameter server to save parameters.
// pserver save checkpoint
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
doCheckpoint
()
error
{
//FIXME: checkpoint is only used by pserver
// and has a constant path of */checkpoints/{pserver_idx}*
<-
s
.
initialized
<-
s
.
initialized
s
.
mu
.
Lock
()
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
defer
s
.
mu
.
Unlock
()
var
paramWithConfig
ParameterWithConfig
cp
:=
make
([]
parameterCheckPoint
,
0
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
for
name
,
opt
:=
range
s
.
optMap
{
paramWithConfig
.
Param
.
Name
=
name
var
pc
parameterCheckPoint
paramWithConfig
.
Param
.
ElementType
=
opt
.
elementType
pc
.
ParamConfig
.
Param
.
Name
=
name
paramWithConfig
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
ParamConfig
.
Param
.
ElementType
=
opt
.
elementType
paramWithConfig
.
State
=
opt
.
GetStates
()
pc
.
ParamConfig
.
Param
.
Content
=
opt
.
GetWeights
()
content
,
err
:=
GetBytes
(
paramWithConfig
)
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
err
}
}
ck
:=
checkpoint
{}
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
h
:=
md5
.
New
()
h
:=
md5
.
New
()
ck
.
Md5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
cpMeta
.
Md5sum
=
h
.
Sum
(
buf
.
Bytes
(
))
ck
.
Timestamp
=
time
.
Now
()
.
String
()
ck
.
Uuid
=
checkpoint_path
+
strconv
.
Itoa
(
s
.
idx
)
cpMetajson
,
err
:=
json
.
Marshal
(
cpMeta
)
ckbytes
,
err
:=
GetBytes
(
ck
)
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Errorln
(
err
)
return
err
}
}
// TODO: according design doc, need to save Uuid to etcd in json format
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
log
.
Info
(
"checkpoint does not exists."
)
log
.
Infof
(
"parameter checkpoint %s"
,
ckbytes
)
if
_
,
err
=
os
.
Stat
(
ck
.
Uuid
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint not exists."
)
}
else
{
}
else
{
err
=
os
.
Remove
(
ck
.
Uuid
)
err
=
os
.
Remove
(
cpMeta
.
UUID
)
log
.
Infof
(
"remove %s"
,
ck
.
Uuid
)
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
}
f
,
err
:=
os
.
Create
(
ck
.
Uuid
)
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
defer
f
.
Close
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Errorln
(
err
)
log
.
Errorln
(
err
)
}
}
writer
:=
bufio
.
NewWriter
(
f
)
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
content
)
_
,
err
=
writer
.
Write
(
buf
.
Bytes
()
)
writer
.
Flush
()
writer
.
Flush
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Errorln
(
err
)
log
.
Errorln
(
err
)
}
}
}
return
nil
return
nil
}
}
go/pserver/service_test.go
浏览文件 @
40295b9e
...
@@ -15,7 +15,8 @@ const (
...
@@ -15,7 +15,8 @@ const (
)
)
func
TestServiceFull
(
t
*
testing
.
T
)
{
func
TestServiceFull
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
Error
(
err
)
t
.
Error
(
err
)
}
}
...
@@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) {
...
@@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) {
if
!
reflect
.
DeepEqual
(
param1
,
p
)
{
if
!
reflect
.
DeepEqual
(
param1
,
p
)
{
t
.
FailNow
()
t
.
FailNow
()
}
}
var
dummy
int
s
.
Save
(
""
,
&
dummy
)
}
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
func
TestMultipleInit
(
t
*
testing
.
T
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录