Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
40295b9e
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
40295b9e
编写于
7月 07, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"fix pserver saving etcd"
上级
bfc3b436
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
116 addition
and
81 deletion
+116
-81
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+4
-1
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+13
-0
go/pserver/optimizer.go
go/pserver/optimizer.go
+2
-2
go/pserver/service.go
go/pserver/service.go
+95
-75
go/pserver/service_test.go
go/pserver/service_test.go
+2
-3
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
40295b9e
...
...
@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
"10"
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
...
...
@@ -31,6 +33,7 @@ func main() {
log
.
SetLevel
(
level
)
var
idx
int
var
cp
pserver
.
Checkpoint
if
*
index
>=
0
{
idx
=
*
index
}
else
{
...
...
@@ -42,7 +45,7 @@ func main() {
}
}
s
,
err
:=
pserver
.
NewService
(
idx
)
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
panic
(
err
)
}
...
...
go/pserver/etcd_client.go
浏览文件 @
40295b9e
...
...
@@ -18,6 +18,8 @@ const (
PsDesired
=
"/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
)
// EtcdClient is the etcd client that the pserver uses for fault
...
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
err
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
))
_
,
err
=
e
.
Put
(
ctx
,
key
,
value
)
cancel
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/optimizer.go
浏览文件 @
40295b9e
...
...
@@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
}
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
)
*
optimizer
{
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
,
State
[]
byte
)
*
optimizer
{
o
:=
&
optimizer
{}
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
p
:=
paramWithConfigs
.
Param
c
:=
paramWithConfigs
.
Config
s
:=
paramWithConfigs
.
State
s
:=
State
log
.
WithFields
(
log
.
Fields
{
"ElementType"
:
p
.
ElementType
,
"ParamSize"
:
len
(
p
.
Content
),
...
...
go/pserver/service.go
浏览文件 @
40295b9e
...
...
@@ -5,10 +5,11 @@ import (
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/
hex
"
"encoding/
json
"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"
...
...
@@ -26,10 +27,6 @@ const (
Uninitialized
=
"pserver not fully initialized"
)
const
(
checkpoint_path
=
"./checkpoints/"
)
// Supported element types
const
(
Int32
ElementType
=
iota
...
...
@@ -51,49 +48,68 @@ type Parameter struct {
type
ParameterWithConfig
struct
{
Param
Parameter
Config
[]
byte
// parameter configuration in Proto Buffer format
State
[]
byte
// parameter training state
}
// Checkpoint of Parameter and State
type
parameterCheckPoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Md5sum
string
`json:"md5sum"`
Timestamp
string
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
parameterCheckPoint
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
// Service is the RPC service for pserver.
type
Service
struct
{
initialized
chan
struct
{}
idx
int
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
initialized
chan
struct
{}
idx
int
checkpointInterval
int
checkpointPath
string
client
*
EtcdClient
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
}
type
checkpoint
struct
{
Uuid
string
Md5sum
string
Timestamp
string
}
// //serialize ParameterWithConfig to byte stream
// func GetBytes(content ...interface{}) ([]byte, error) {
//serialize ParameterWithConfig to byte stream
func
GetBytes
(
content
...
interface
{})
([]
byte
,
error
)
{
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
content
)
if
err
!=
nil
{
return
nil
,
err
}
return
buf
.
Bytes
(),
nil
}
// var buf bytes.Buffer
// encoder := gob.NewEncoder(&buf)
// err := encoder.Encode(content)
// if err != nil {
// return nil, err
// }
// return buf.Bytes(), nil
// }
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
idx
:
idx
,
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
),
checkpointPath
:
path
,
client
:
client
,
}
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
initialized
=
make
(
chan
struct
{})
gob
.
Register
(
ParameterWithConfig
{})
gob
.
Register
(
checkpoint
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
}
}
return
s
,
nil
}
...
...
@@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
}
// Save tells the parameter server to save parameters.
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
//FIXME: checkpoint is only used by pserver
// and has a constant path of */checkpoints/{pserver_idx}*
// pserver save checkpoint
func
(
s
*
Service
)
doCheckpoint
()
error
{
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
var
paramWithConfig
ParameterWithConfig
cp
:=
make
([]
parameterCheckPoint
,
0
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
paramWithConfig
.
Param
.
Name
=
name
paramWithConfig
.
Param
.
ElementType
=
opt
.
elementType
paramWithConfig
.
Param
.
Content
=
opt
.
GetWeights
()
paramWithConfig
.
State
=
opt
.
GetStates
()
content
,
err
:=
GetBytes
(
paramWithConfig
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
ck
:=
checkpoint
{}
h
:=
md5
.
New
()
ck
.
Md5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
ck
.
Timestamp
=
time
.
Now
()
.
String
()
ck
.
Uuid
=
checkpoint_path
+
strconv
.
Itoa
(
s
.
idx
)
ckbytes
,
err
:=
GetBytes
(
ck
)
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
// TODO: according design doc, need to save Uuid to etcd in json format
// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
log
.
Infof
(
"parameter checkpoint %s"
,
ckbytes
)
if
_
,
err
=
os
.
Stat
(
ck
.
Uuid
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint not exists."
)
}
else
{
err
=
os
.
Remove
(
ck
.
Uuid
)
log
.
Infof
(
"remove %s"
,
ck
.
Uuid
)
}
f
,
err
:=
os
.
Create
(
ck
.
Uuid
)
defer
f
.
Close
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
content
)
writer
.
Flush
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
var
pc
parameterCheckPoint
pc
.
ParamConfig
.
Param
.
Name
=
name
pc
.
ParamConfig
.
Param
.
ElementType
=
opt
.
elementType
pc
.
ParamConfig
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
cp
)
if
err
!=
nil
{
return
err
}
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
h
:=
md5
.
New
()
cpMeta
.
Md5sum
=
h
.
Sum
(
buf
.
Bytes
())
cpMetajson
,
err
:=
json
.
Marshal
(
cpMeta
)
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
if
err
!=
nil
{
return
err
}
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
buf
.
Bytes
())
writer
.
Flush
()
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
return
nil
}
go/pserver/service_test.go
浏览文件 @
40295b9e
...
...
@@ -15,7 +15,8 @@ const (
)
func
TestServiceFull
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) {
if
!
reflect
.
DeepEqual
(
param1
,
p
)
{
t
.
FailNow
()
}
var
dummy
int
s
.
Save
(
""
,
&
dummy
)
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录