Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
15f021a9
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
15f021a9
编写于
7月 11, 2017
作者:
D
dzhwinter
提交者:
GitHub
7月 11, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2716 from dzhwinter/save_state
Pserver Save state
上级
8c615e8f
e8296ff2
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
146 addition
and
31 deletion
+146
-31
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+6
-2
go/pserver/etcd_client.go
go/pserver/etcd_client.go
+13
-0
go/pserver/optimizer.go
go/pserver/optimizer.go
+15
-3
go/pserver/optimizer_test.go
go/pserver/optimizer_test.go
+1
-1
go/pserver/service.go
go/pserver/service.go
+99
-11
go/pserver/service_test.go
go/pserver/service_test.go
+12
-14
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
15f021a9
...
...
@@ -20,6 +20,8 @@ func main() {
"comma separated endpoint string for pserver to connect to etcd"
)
etcdTimeout
:=
flag
.
Int
(
"etcd-timeout"
,
5
,
"timeout for etcd calls"
)
numPservers
:=
flag
.
Int
(
"num-pservers"
,
1
,
"total pserver count in a training job"
)
checkpointPath
:=
flag
.
String
(
"checkpoint-path"
,
"/checkpoints/"
,
"save checkpoint path"
)
checkpointInterval
:=
flag
.
Int
(
"checkpoint-interval"
,
600
,
"save checkpoint per interval seconds"
)
logLevel
:=
flag
.
String
(
"log-level"
,
"info"
,
"log level, possible values: debug, info, warning, error, fatal, panic"
)
flag
.
Parse
()
...
...
@@ -31,18 +33,20 @@ func main() {
log
.
SetLevel
(
level
)
var
idx
int
var
cp
pserver
.
Checkpoint
var
e
*
pserver
.
EtcdClient
if
*
index
>=
0
{
idx
=
*
index
}
else
{
timeout
:=
time
.
Second
*
time
.
Duration
((
*
etcdTimeout
))
e
:
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
e
=
pserver
.
NewEtcdClient
(
*
etcdEndpoint
,
*
numPservers
,
timeout
)
idx
,
err
=
e
.
Register
()
if
err
!=
nil
{
panic
(
err
)
}
}
s
,
err
:=
pserver
.
NewService
(
idx
)
s
,
err
:=
pserver
.
NewService
(
idx
,
*
checkpointInterval
,
*
checkpointPath
,
e
,
cp
)
if
err
!=
nil
{
panic
(
err
)
}
...
...
go/pserver/etcd_client.go
浏览文件 @
15f021a9
...
...
@@ -18,6 +18,8 @@ const (
PsDesired
=
"/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath
=
"/ps/"
// PsCheckpoint is the etcd path for store checkpoints information
PsCheckpoint
=
"/checkpoints/"
)
// EtcdClient is the etcd client that the pserver uses for fault
...
...
@@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
return
idx
,
nil
}
// PutKey put into etcd with value by key specified
func
(
e
*
EtcdClient
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
int
)
error
{
ctx
,
cancel
:=
context
.
WithTimeout
(
context
.
Background
(),
time
.
Second
*
time
.
Duration
(
timeout
))
_
,
err
:=
e
.
etcdClient
.
Put
(
ctx
,
key
,
string
(
value
))
cancel
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/optimizer.go
浏览文件 @
15f021a9
...
...
@@ -35,22 +35,28 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
}
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
)
*
optimizer
{
func
newOptimizer
(
paramWithConfigs
ParameterWithConfig
,
State
[]
byte
)
*
optimizer
{
o
:=
&
optimizer
{}
o
.
elementType
=
paramWithConfigs
.
Param
.
ElementType
p
:=
paramWithConfigs
.
Param
c
:=
paramWithConfigs
.
Config
s
:=
State
log
.
WithFields
(
log
.
Fields
{
"ElementType"
:
p
.
ElementType
,
"ParamSize"
:
len
(
p
.
Content
),
"ConfigSize"
:
len
(
c
),
"StateSize"
:
len
(
s
),
})
.
Info
(
"New Optimizer Created with config:"
)
var
cbuffer
unsafe
.
Pointer
cbuffer
=
C
.
malloc
(
C
.
size_t
(
len
(
p
.
Content
)))
C
.
memcpy
(
cbuffer
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
C
.
size_t
(
len
(
p
.
Content
)))
var
cstate
unsafe
.
Pointer
if
len
(
s
)
!=
0
{
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
o
.
opt
=
C
.
paddle_create_optimizer
((
*
C
.
uchar
)(
&
c
[
0
]),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
(
*
C
.
char
)(
nullPtr
),
0
)
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
C
.
int
(
len
(
p
.
Content
)
/
C
.
sizeof_float
),
(
*
C
.
char
)(
cstate
),
C
.
int
(
len
(
s
)))
return
o
}
...
...
@@ -60,6 +66,12 @@ func (o *optimizer) GetWeights() []byte {
return
cArrayToSlice
(
buffer
,
int
(
bufferLen
)
*
C
.
sizeof_float
)
}
func
(
o
*
optimizer
)
GetStates
()
[]
byte
{
var
cbuffer
*
C
.
char
cbuffer_len
:=
C
.
paddle_optimizer_get_state
(
o
.
opt
,
&
cbuffer
)
return
cArrayToSlice
(
unsafe
.
Pointer
(
cbuffer
),
int
(
cbuffer_len
))
}
func
(
o
*
optimizer
)
UpdateParameter
(
g
Gradient
)
error
{
if
o
.
elementType
!=
g
.
ElementType
{
return
fmt
.
Errorf
(
"Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v"
,
g
.
Name
,
o
.
elementType
,
g
.
ElementType
)
...
...
go/pserver/optimizer_test.go
浏览文件 @
15f021a9
...
...
@@ -19,6 +19,6 @@ func TestOptimizerCreateRelease(t *testing.T) {
Param
:
p
,
Config
:
config
,
}
o
:=
newOptimizer
(
param
)
o
:=
newOptimizer
(
param
,
nil
)
o
.
Cleanup
()
}
go/pserver/service.go
浏览文件 @
15f021a9
package
pserver
import
(
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strconv"
"sync"
"time"
log
"github.com/sirupsen/logrus"
)
// ElementType is the type of elements of a Parameter.
...
...
@@ -39,6 +51,22 @@ type ParameterWithConfig struct {
Config
[]
byte
// parameter configuration in Proto Buffer format
}
// ParameterCheckpoint is Parameter and State checkpoint
type
ParameterCheckpoint
struct
{
ParamConfig
ParameterWithConfig
State
[]
byte
}
// checkpoint signature
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Md5sum
string
`json:"md5sum"`
Timestamp
string
`json:"timestamp"`
}
// Checkpoint is the pserver shard persist in file
type
Checkpoint
[]
ParameterCheckpoint
// Gradient is the gradient of the parameter.
type
Gradient
Parameter
...
...
@@ -46,19 +74,32 @@ type Gradient Parameter
type
Service
struct
{
initialized
chan
struct
{}
idx
int
checkpointInterval
time
.
Duration
checkpointPath
string
client
*
EtcdClient
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func
NewService
(
idx
int
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
seconds
int
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
idx
:
idx
,
checkpointInterval
:
time
.
Second
*
time
.
Duration
(
seconds
),
checkpointPath
:
path
,
client
:
client
,
}
s
.
optMap
=
make
(
map
[
string
]
*
optimizer
)
s
.
initialized
=
make
(
chan
struct
{})
if
cp
!=
nil
{
for
_
,
item
:=
range
cp
{
p
:=
item
.
ParamConfig
st
:=
item
.
State
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
st
)
}
}
return
s
,
nil
}
...
...
@@ -78,7 +119,7 @@ func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) er
// TODO(helin): check if paramWithConfigs.Param.Content is
// properly memory aligned, if not, make copy to a memory
// aligned region.
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
)
s
.
optMap
[
paramWithConfigs
.
Param
.
Name
]
=
newOptimizer
(
paramWithConfigs
,
nil
)
return
nil
}
...
...
@@ -139,10 +180,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return
nil
}
//
Save tells the parameter server to save parameters.
func
(
s
*
Service
)
Save
(
path
string
,
dummy
*
int
)
error
{
//
pserver save checkpoint
func
(
s
*
Service
)
doCheckpoint
(
)
error
{
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
cp
:=
make
([]
ParameterCheckpoint
,
0
,
len
(
s
.
optMap
))
index
:=
0
for
name
,
opt
:=
range
s
.
optMap
{
var
pc
ParameterCheckpoint
pc
.
ParamConfig
.
Param
.
Name
=
name
pc
.
ParamConfig
.
Param
.
ElementType
=
opt
.
elementType
pc
.
ParamConfig
.
Param
.
Content
=
opt
.
GetWeights
()
pc
.
State
=
opt
.
GetStates
()
cp
[
index
]
=
pc
index
++
}
var
buf
bytes
.
Buffer
encoder
:=
gob
.
NewEncoder
(
&
buf
)
err
:=
encoder
.
Encode
(
cp
)
if
err
!=
nil
{
return
err
}
// TODO
cpMeta
:=
checkpointMeta
{}
cpMeta
.
UUID
=
s
.
checkpointPath
+
strconv
.
Itoa
(
s
.
idx
)
cpMeta
.
Timestamp
=
time
.
Now
()
.
String
()
h
:=
md5
.
New
()
cpMeta
.
Md5sum
=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMetajson
,
_
:=
json
.
Marshal
(
cpMeta
)
err
=
s
.
client
.
PutKey
(
filepath
.
Join
(
PsCheckpoint
,
strconv
.
Itoa
(
s
.
idx
)),
cpMetajson
,
3
)
if
err
!=
nil
{
return
err
}
if
_
,
err
=
os
.
Stat
(
cpMeta
.
UUID
);
os
.
IsNotExist
(
err
)
{
log
.
Info
(
"checkpoint does not exists."
)
}
else
{
err
=
os
.
Remove
(
cpMeta
.
UUID
)
log
.
Infof
(
"checkpoint %s already exsits, removing "
,
cpMeta
.
UUID
)
}
f
,
err
:=
os
.
Create
(
cpMeta
.
UUID
)
defer
f
.
Close
()
if
err
!=
nil
{
return
err
}
writer
:=
bufio
.
NewWriter
(
f
)
_
,
err
=
writer
.
Write
(
buf
.
Bytes
())
writer
.
Flush
()
if
err
!=
nil
{
return
err
}
return
nil
}
go/pserver/service_test.go
浏览文件 @
15f021a9
...
...
@@ -15,7 +15,8 @@ const (
)
func
TestServiceFull
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -86,7 +87,8 @@ func TestServiceFull(t *testing.T) {
}
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -102,7 +104,8 @@ func TestMultipleInit(t *testing.T) {
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
err
=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
FailNow
()
...
...
@@ -110,7 +113,8 @@ func TestUninitialized(t *testing.T) {
}
func
TestBlockUntilInitialized
(
t
*
testing
.
T
)
{
s
,
err
:=
pserver
.
NewService
(
0
)
var
cp
pserver
.
Checkpoint
s
,
err
:=
pserver
.
NewService
(
0
,
1
,
""
,
nil
,
cp
)
if
err
!=
nil
{
t
.
Error
(
err
)
}
...
...
@@ -128,16 +132,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch
<-
struct
{}{}
}()
wg
.
Add
(
1
)
go
func
()
{
err
:=
s
.
Save
(
""
,
nil
)
if
err
!=
nil
{
errCh
<-
err
}
wg
.
Done
()
ch
<-
struct
{}{}
}()
time
.
Sleep
(
50
*
time
.
Millisecond
)
select
{
...
...
@@ -170,3 +164,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
}
func
TestCheckpointSpeed
(
t
*
testing
.
T
)
{
//TODO(zhihong): test speed
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录