Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
b1cbdf03
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
b1cbdf03
编写于
10月 26, 2017
作者:
H
helinwang
提交者:
GitHub
10月 26, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5102 from helinwang/checkpoint
Fix pserver checkpoint
上级
66476fc7
00e2dcf3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
126 addition
and
34 deletion
+126
-34
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+2
-2
go/pserver/optimizer.go
go/pserver/optimizer.go
+7
-1
go/pserver/service.go
go/pserver/service.go
+31
-27
go/pserver/service_internal_test.go
go/pserver/service_internal_test.go
+86
-0
go/pserver/service_test.go
go/pserver/service_test.go
+0
-4
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
b1cbdf03
...
@@ -67,7 +67,7 @@ func main() {
...
@@ -67,7 +67,7 @@ func main() {
cp
,
err
=
pserver
.
LoadCheckpoint
(
e
,
idx
)
cp
,
err
=
pserver
.
LoadCheckpoint
(
e
,
idx
)
if
err
!=
nil
{
if
err
!=
nil
{
if
err
==
pserver
.
ErrCheckpointNotFound
{
if
err
==
pserver
.
ErrCheckpointNotFound
{
log
.
Info
(
"
Could not find the pserver checkpoint."
)
log
.
Info
(
"
load checkpoint error"
,
"error"
,
err
)
}
else
{
}
else
{
panic
(
err
)
panic
(
err
)
}
}
...
@@ -99,7 +99,7 @@ func main() {
...
@@ -99,7 +99,7 @@ func main() {
candy
.
Must
(
err
)
candy
.
Must
(
err
)
go
func
()
{
go
func
()
{
log
.
Info
(
"s
tart
ing pserver"
,
log
.
Ctx
{
"port"
:
*
port
})
log
.
Info
(
"s
erv
ing pserver"
,
log
.
Ctx
{
"port"
:
*
port
})
err
=
http
.
Serve
(
l
,
nil
)
err
=
http
.
Serve
(
l
,
nil
)
candy
.
Must
(
err
)
candy
.
Must
(
err
)
}()
}()
...
...
go/pserver/optimizer.go
浏览文件 @
b1cbdf03
...
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
...
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
}
var
cptr
(
*
C
.
uchar
)
if
len
(
c
)
>
0
{
cptr
=
(
*
C
.
uchar
)(
&
c
[
0
])
}
else
{
log
.
Error
(
"empty config"
,
"param name"
,
paramWithConfigs
.
Param
.
Name
)
}
o
.
config
=
c
o
.
config
=
c
o
.
opt
=
C
.
paddle_create_optimizer
(
o
.
opt
=
C
.
paddle_create_optimizer
(
(
*
C
.
uchar
)(
&
c
[
0
])
,
cptr
,
C
.
int
(
len
(
c
)),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
cbuffer
,
...
...
go/pserver/service.go
浏览文件 @
b1cbdf03
...
@@ -17,12 +17,11 @@ package pserver
...
@@ -17,12 +17,11 @@ package pserver
import
(
import
(
"bufio"
"bufio"
"bytes"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/gob"
"encoding/hex"
"encoding/json"
"encoding/json"
"errors"
"errors"
"fmt"
"fmt"
"hash/crc32"
"io/ioutil"
"io/ioutil"
"os"
"os"
"path"
"path"
...
@@ -40,7 +39,7 @@ type ElementType int
...
@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
// not be found.
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found"
)
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found
in etcd
"
)
// RPC error message.
// RPC error message.
const
(
const
(
...
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
...
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type
checkpointMeta
struct
{
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
UUID
string
`json:"uuid"`
Path
string
`json:"path"`
Path
string
`json:"path"`
MD5
string
`json:"md5
"`
CRC32
uint32
`json:"crc32
"`
Timestamp
int64
`json:"timestamp"`
Timestamp
int64
`json:"timestamp"`
}
}
...
@@ -92,7 +91,7 @@ type Service struct {
...
@@ -92,7 +91,7 @@ type Service struct {
idx
int
idx
int
checkpointInterval
time
.
Duration
checkpointInterval
time
.
Duration
checkpointPath
string
checkpointPath
string
client
*
EtcdClient
client
KVStore
mu
sync
.
Mutex
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
optMap
map
[
string
]
*
optimizer
...
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
...
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State
[]
byte
State
[]
byte
}
}
func
loadMeta
(
e
*
EtcdClient
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
// KVStore is the key-value store abstraction accepted by loadMeta,
// LoadCheckpoint and NewService, and held by Service as its client.
// Depending on this interface rather than on *EtcdClient lets tests supply
// an in-memory implementation.
type KVStore interface {
	// GetKey returns the value stored under key.
	// NOTE(review): timeout presumably bounds how long the lookup may
	// take — confirm against the EtcdClient implementation.
	GetKey(key string, timeout time.Duration) ([]byte, error)
	// PutKey stores value under key.
	// NOTE(review): withLease presumably attaches the key to the client's
	// lease — confirm against the EtcdClient implementation.
	PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}
func
loadMeta
(
e
KVStore
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
v
,
err
:=
e
.
GetKey
(
PsCheckpoint
+
strconv
.
Itoa
(
idx
),
3
*
time
.
Second
)
v
,
err
:=
e
.
GetKey
(
PsCheckpoint
+
strconv
.
Itoa
(
idx
),
3
*
time
.
Second
)
if
err
!=
nil
{
if
err
!=
nil
{
return
return
...
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
...
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}
}
// LoadCheckpoint loads checkpoint from file.
// LoadCheckpoint loads checkpoint from file.
func
LoadCheckpoint
(
e
*
EtcdClient
,
idx
int
)
(
Checkpoint
,
error
)
{
func
LoadCheckpoint
(
e
KVStore
,
idx
int
)
(
Checkpoint
,
error
)
{
log
.
Info
(
"Loading checkpoint"
,
"pserver index"
,
idx
)
log
.
Info
(
"Loading checkpoint"
,
"pserver index"
,
idx
)
defer
traceTime
(
time
.
Now
(),
"load checkpoint"
)
defer
traceTime
(
time
.
Now
(),
"load checkpoint"
)
...
@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
...
@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return
nil
,
err
return
nil
,
err
}
}
// TODO(helin): change MD5 to CRC since CRC is better for file
crc32
:=
crc32
.
ChecksumIEEE
(
content
)
// checksum in our use case (emphasize speed over security).
if
crc32
!=
cpMeta
.
CRC32
{
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
if
md5
!=
cpMeta
.
MD5
{
return
nil
,
errors
.
New
(
WrongChecksum
)
return
nil
,
errors
.
New
(
WrongChecksum
)
}
}
...
@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
...
@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if
err
=
dec
.
Decode
(
&
cp
);
err
!=
nil
{
if
err
=
dec
.
Decode
(
&
cp
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
return
cp
,
nil
return
cp
,
nil
}
}
// NewService creates a new service, will bypass etcd registration if no
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recover from the checkpoint file if a specified checkpoint exists.
// endpoints specified. It will recover from the checkpoint file if a specified checkpoint exists.
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
KVStore
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
s
:=
&
Service
{
idx
:
idx
,
idx
:
idx
,
checkpointInterval
:
interval
,
checkpointInterval
:
interval
,
...
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
...
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
}
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
item
.
State
)
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
item
.
State
)
}
}
close
(
s
.
initialized
)
}
}
return
s
,
nil
return
s
,
nil
}
}
...
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
...
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for
range
t
{
for
range
t
{
err
:=
s
.
checkpoint
()
err
:=
s
.
checkpoint
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Error
(
"
finish init params
error"
,
log
.
Ctx
{
"error"
:
err
})
log
.
Error
(
"
checkpoint
error"
,
log
.
Ctx
{
"error"
:
err
})
}
}
}
}
}()
}()
...
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
...
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter
.
Name
=
name
parameter
.
Name
=
name
parameter
.
ElementType
=
opt
.
elementType
parameter
.
ElementType
=
opt
.
elementType
parameter
.
Content
=
opt
.
GetWeights
()
parameter
.
Content
=
opt
.
GetWeights
()
log
.
Info
(
"sending parameter to the trainer"
,
"name"
,
parameter
.
Name
,
"size"
,
len
(
parameter
.
Content
),
"type"
,
parameter
.
ElementType
)
log
.
Info
(
"sending parameter to the trainer"
,
"name"
,
parameter
.
Name
,
"size"
,
len
(
parameter
.
Content
),
"type"
,
parameter
.
ElementType
)
return
nil
return
nil
}
}
...
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
...
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta
,
err
:=
loadMeta
(
s
.
client
,
s
.
idx
)
oldMeta
,
err
:=
loadMeta
(
s
.
client
,
s
.
idx
)
if
err
==
ErrCheckpointNotFound
{
if
err
==
ErrCheckpointNotFound
{
log
.
Info
(
"
Do not have existing checkpoint.
"
)
log
.
Info
(
"
old meta not found, skip removing old meta
"
)
err
=
nil
err
=
nil
}
else
if
err
==
nil
{
log
.
Info
(
"removing old meta"
)
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Error
(
"remove old meta file error"
,
log
.
Ctx
{
"error"
:
rmErr
})
}
}
}
}
if
err
!=
nil
{
if
err
!=
nil
{
return
return
}
}
h
:=
md5
.
New
()
crc32
:=
crc32
.
ChecksumIEEE
(
buf
.
Bytes
())
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMeta
:=
checkpointMeta
{
cpMeta
:=
checkpointMeta
{
UUID
:
id
,
UUID
:
id
,
Timestamp
:
time
.
Now
()
.
UnixNano
(),
Timestamp
:
time
.
Now
()
.
UnixNano
(),
MD5
:
md5
,
CRC32
:
crc32
,
Path
:
p
,
Path
:
p
,
}
}
...
@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
...
@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
return
}
}
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Error
(
"remove old meta file error"
,
log
.
Ctx
{
"error"
:
rmErr
})
}
}
return
return
}
}
go/pserver/service_internal_test.go
0 → 100644
浏览文件 @
b1cbdf03
package
pserver
import
(
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const
testDir
=
"./test_data"
// myKV is an in-memory key-value store that the tests in this file pass to
// NewService and LoadCheckpoint. Its zero value is ready to use: the backing
// map is created by the first PutKey.
type myKV struct {
	m map[string][]byte
}

// GetKey returns the value stored under key, or nil if the key has never
// been written. The returned error is always nil and timeout is ignored.
func (kv *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
	// Indexing a nil map is safe and yields the zero value, so a read never
	// needs to allocate the map.
	return kv.m[key], nil
}

// PutKey stores value under key, replacing any previous value. The slice is
// stored as-is (not copied). timeout and withLease are ignored and the
// returned error is always nil.
func (kv *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
	// Writing to a nil map panics, so create it lazily on the first write.
	if kv.m == nil {
		kv.m = make(map[string][]byte)
	}
	kv.m[key] = value
	return nil
}
func
TestCheckpoint
(
t
*
testing
.
T
)
{
kv
:=
&
myKV
{}
s
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
nil
)
assert
.
Nil
(
t
,
err
)
err
=
s
.
checkpoint
()
assert
.
Nil
(
t
,
err
)
_
,
err
=
LoadCheckpoint
(
kv
,
0
)
assert
.
Nil
(
t
,
err
)
}
// float32ToByte returns the little-endian binary encoding of f.
//
// It panics if the encoding fails. Returning the partially written buffer
// would only make the calling test fail later with a misleading content
// mismatch.
func float32ToByte(f float32) []byte {
	var buf bytes.Buffer
	if err := binary.Write(&buf, binary.LittleEndian, f); err != nil {
		panic(fmt.Sprintf("binary.Write failed: %v", err))
	}
	return buf.Bytes()
}
func
TestCheckpointWithData
(
t
*
testing
.
T
)
{
kv
:=
&
myKV
{}
s
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
nil
)
assert
.
Nil
(
t
,
err
)
var
content
[]
byte
for
i
:=
0
;
i
<
50000
;
i
++
{
content
=
append
(
content
,
float32ToByte
(
float32
(
i
))
...
)
}
p1
:=
Parameter
{
Name
:
"p1"
,
ElementType
:
1
,
Content
:
content
}
err
=
s
.
InitParam
(
ParameterWithConfig
{
Param
:
p1
},
nil
)
assert
.
Nil
(
t
,
err
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
assert
.
Nil
(
t
,
err
)
var
p2
Parameter
err
=
s
.
GetParam
(
p1
.
Name
,
&
p2
)
assert
.
Nil
(
t
,
err
)
assert
.
Equal
(
t
,
p1
,
p2
)
err
=
s
.
checkpoint
()
assert
.
Nil
(
t
,
err
)
cp
,
err
:=
LoadCheckpoint
(
kv
,
0
)
assert
.
Nil
(
t
,
err
)
s1
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
cp
)
assert
.
Nil
(
t
,
err
)
var
p3
Parameter
err
=
s1
.
GetParam
(
p1
.
Name
,
&
p3
)
assert
.
Nil
(
t
,
err
)
assert
.
Equal
(
t
,
p1
,
p3
)
}
go/pserver/service_test.go
浏览文件 @
b1cbdf03
...
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
wg
.
Wait
()
}
}
func
TestCheckpointSpeed
(
t
*
testing
.
T
)
{
//TODO(zhihong): test speed
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录