Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b1cbdf03
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1cbdf03
编写于
10月 26, 2017
作者:
H
helinwang
提交者:
GitHub
10月 26, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5102 from helinwang/checkpoint
Fix pserver checkpoint
上级
66476fc7
00e2dcf3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
126 addition
and
34 deletion
+126
-34
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+2
-2
go/pserver/optimizer.go
go/pserver/optimizer.go
+7
-1
go/pserver/service.go
go/pserver/service.go
+31
-27
go/pserver/service_internal_test.go
go/pserver/service_internal_test.go
+86
-0
go/pserver/service_test.go
go/pserver/service_test.go
+0
-4
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
b1cbdf03
...
@@ -67,7 +67,7 @@ func main() {
...
@@ -67,7 +67,7 @@ func main() {
cp
,
err
=
pserver
.
LoadCheckpoint
(
e
,
idx
)
cp
,
err
=
pserver
.
LoadCheckpoint
(
e
,
idx
)
if
err
!=
nil
{
if
err
!=
nil
{
if
err
==
pserver
.
ErrCheckpointNotFound
{
if
err
==
pserver
.
ErrCheckpointNotFound
{
log
.
Info
(
"
Could not find the pserver checkpoint."
)
log
.
Info
(
"
load checkpoint error"
,
"error"
,
err
)
}
else
{
}
else
{
panic
(
err
)
panic
(
err
)
}
}
...
@@ -99,7 +99,7 @@ func main() {
...
@@ -99,7 +99,7 @@ func main() {
candy
.
Must
(
err
)
candy
.
Must
(
err
)
go
func
()
{
go
func
()
{
log
.
Info
(
"s
tart
ing pserver"
,
log
.
Ctx
{
"port"
:
*
port
})
log
.
Info
(
"s
erv
ing pserver"
,
log
.
Ctx
{
"port"
:
*
port
})
err
=
http
.
Serve
(
l
,
nil
)
err
=
http
.
Serve
(
l
,
nil
)
candy
.
Must
(
err
)
candy
.
Must
(
err
)
}()
}()
...
...
go/pserver/optimizer.go
浏览文件 @
b1cbdf03
...
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
...
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
}
var
cptr
(
*
C
.
uchar
)
if
len
(
c
)
>
0
{
cptr
=
(
*
C
.
uchar
)(
&
c
[
0
])
}
else
{
log
.
Error
(
"empty config"
,
"param name"
,
paramWithConfigs
.
Param
.
Name
)
}
o
.
config
=
c
o
.
config
=
c
o
.
opt
=
C
.
paddle_create_optimizer
(
o
.
opt
=
C
.
paddle_create_optimizer
(
(
*
C
.
uchar
)(
&
c
[
0
])
,
cptr
,
C
.
int
(
len
(
c
)),
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
cbuffer
,
...
...
go/pserver/service.go
浏览文件 @
b1cbdf03
...
@@ -17,12 +17,11 @@ package pserver
...
@@ -17,12 +17,11 @@ package pserver
import
(
import
(
"bufio"
"bufio"
"bytes"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/gob"
"encoding/hex"
"encoding/json"
"encoding/json"
"errors"
"errors"
"fmt"
"fmt"
"hash/crc32"
"io/ioutil"
"io/ioutil"
"os"
"os"
"path"
"path"
...
@@ -40,7 +39,7 @@ type ElementType int
...
@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
// not be found.
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found"
)
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found
in etcd
"
)
// RPC error message.
// RPC error message.
const
(
const
(
...
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
...
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type
checkpointMeta
struct
{
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
UUID
string
`json:"uuid"`
Path
string
`json:"path"`
Path
string
`json:"path"`
MD5
string
`json:"md5
"`
CRC32
uint32
`json:"crc32
"`
Timestamp
int64
`json:"timestamp"`
Timestamp
int64
`json:"timestamp"`
}
}
...
@@ -92,7 +91,7 @@ type Service struct {
...
@@ -92,7 +91,7 @@ type Service struct {
idx
int
idx
int
checkpointInterval
time
.
Duration
checkpointInterval
time
.
Duration
checkpointPath
string
checkpointPath
string
client
*
EtcdClient
client
KVStore
mu
sync
.
Mutex
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
optMap
map
[
string
]
*
optimizer
...
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
...
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State
[]
byte
State
[]
byte
}
}
func
loadMeta
(
e
*
EtcdClient
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
// KVStore is the key-value storage the pserver reads and writes through.
// Service, loadMeta, LoadCheckpoint and NewService accept this interface
// rather than a concrete client, so tests can supply an in-memory store.
type KVStore interface {
	// GetKey returns the value stored under key. timeout bounds how long
	// the lookup may take.
	GetKey(key string, timeout time.Duration) (value []byte, err error)

	// PutKey stores value under key. timeout bounds how long the write may
	// take. NOTE(review): withLease presumably ties the key to the client's
	// lease so it expires with it; confirm against the etcd client.
	PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}
func
loadMeta
(
e
KVStore
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
v
,
err
:=
e
.
GetKey
(
PsCheckpoint
+
strconv
.
Itoa
(
idx
),
3
*
time
.
Second
)
v
,
err
:=
e
.
GetKey
(
PsCheckpoint
+
strconv
.
Itoa
(
idx
),
3
*
time
.
Second
)
if
err
!=
nil
{
if
err
!=
nil
{
return
return
...
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
...
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}
}
// LoadCheckpoint loads checkpoint from file.
// LoadCheckpoint loads checkpoint from file.
func
LoadCheckpoint
(
e
*
EtcdClient
,
idx
int
)
(
Checkpoint
,
error
)
{
func
LoadCheckpoint
(
e
KVStore
,
idx
int
)
(
Checkpoint
,
error
)
{
log
.
Info
(
"Loading checkpoint"
,
"pserver index"
,
idx
)
log
.
Info
(
"Loading checkpoint"
,
"pserver index"
,
idx
)
defer
traceTime
(
time
.
Now
(),
"load checkpoint"
)
defer
traceTime
(
time
.
Now
(),
"load checkpoint"
)
...
@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
...
@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return
nil
,
err
return
nil
,
err
}
}
// TODO(helin): change MD5 to CRC since CRC is better for file
crc32
:=
crc32
.
ChecksumIEEE
(
content
)
// checksum in our use case (emphasize speed over security).
if
crc32
!=
cpMeta
.
CRC32
{
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
if
md5
!=
cpMeta
.
MD5
{
return
nil
,
errors
.
New
(
WrongChecksum
)
return
nil
,
errors
.
New
(
WrongChecksum
)
}
}
...
@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
...
@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if
err
=
dec
.
Decode
(
&
cp
);
err
!=
nil
{
if
err
=
dec
.
Decode
(
&
cp
);
err
!=
nil
{
return
nil
,
err
return
nil
,
err
}
}
return
cp
,
nil
return
cp
,
nil
}
}
// NewService creates a new service, will bypass etcd registration if no
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recover from the checkpoint file if a specified checkpoint exists.
// endpoints specified. It will recover from the checkpoint file if a specified checkpoint exists.
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
KVStore
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
s
:=
&
Service
{
idx
:
idx
,
idx
:
idx
,
checkpointInterval
:
interval
,
checkpointInterval
:
interval
,
...
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
...
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
}
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
item
.
State
)
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
item
.
State
)
}
}
close
(
s
.
initialized
)
}
}
return
s
,
nil
return
s
,
nil
}
}
...
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
...
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for
range
t
{
for
range
t
{
err
:=
s
.
checkpoint
()
err
:=
s
.
checkpoint
()
if
err
!=
nil
{
if
err
!=
nil
{
log
.
Error
(
"
finish init params
error"
,
log
.
Ctx
{
"error"
:
err
})
log
.
Error
(
"
checkpoint
error"
,
log
.
Ctx
{
"error"
:
err
})
}
}
}
}
}()
}()
...
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
...
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter
.
Name
=
name
parameter
.
Name
=
name
parameter
.
ElementType
=
opt
.
elementType
parameter
.
ElementType
=
opt
.
elementType
parameter
.
Content
=
opt
.
GetWeights
()
parameter
.
Content
=
opt
.
GetWeights
()
log
.
Info
(
"sending parameter to the trainer"
,
"name"
,
parameter
.
Name
,
"size"
,
len
(
parameter
.
Content
),
"type"
,
parameter
.
ElementType
)
log
.
Info
(
"sending parameter to the trainer"
,
"name"
,
parameter
.
Name
,
"size"
,
len
(
parameter
.
Content
),
"type"
,
parameter
.
ElementType
)
return
nil
return
nil
}
}
...
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
...
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta
,
err
:=
loadMeta
(
s
.
client
,
s
.
idx
)
oldMeta
,
err
:=
loadMeta
(
s
.
client
,
s
.
idx
)
if
err
==
ErrCheckpointNotFound
{
if
err
==
ErrCheckpointNotFound
{
log
.
Info
(
"
Do not have existing checkpoint.
"
)
log
.
Info
(
"
old meta not found, skip removing old meta
"
)
err
=
nil
err
=
nil
}
else
if
err
==
nil
{
log
.
Info
(
"removing old meta"
)
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Error
(
"remove old meta file error"
,
log
.
Ctx
{
"error"
:
rmErr
})
}
}
}
}
if
err
!=
nil
{
if
err
!=
nil
{
return
return
}
}
h
:=
md5
.
New
()
crc32
:=
crc32
.
ChecksumIEEE
(
buf
.
Bytes
())
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
cpMeta
:=
checkpointMeta
{
cpMeta
:=
checkpointMeta
{
UUID
:
id
,
UUID
:
id
,
Timestamp
:
time
.
Now
()
.
UnixNano
(),
Timestamp
:
time
.
Now
()
.
UnixNano
(),
MD5
:
md5
,
CRC32
:
crc32
,
Path
:
p
,
Path
:
p
,
}
}
...
@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
...
@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
return
}
}
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Error
(
"remove old meta file error"
,
log
.
Ctx
{
"error"
:
rmErr
})
}
}
return
return
}
}
go/pserver/service_internal_test.go
0 → 100644
浏览文件 @
b1cbdf03
package
pserver
import
(
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const
testDir
=
"./test_data"
// myKV is an in-memory key-value store that the tests pass where the
// service expects its key-value client. The zero value is ready to use.
type myKV struct {
	m map[string][]byte
}

// GetKey returns the value stored under key, or nil when the key has never
// been written. timeout is ignored and the returned error is always nil.
func (kv *myKV) GetKey(key string, timeout time.Duration) ([]byte, error) {
	// Indexing a nil map is safe and yields the zero value, so a read does
	// not need to allocate the map.
	return kv.m[key], nil
}

// PutKey stores a copy of value under key, replacing any previous value.
// timeout and withLease are ignored and the returned error is always nil.
func (kv *myKV) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
	if kv.m == nil {
		kv.m = make(map[string][]byte)
	}
	// Copy so that later mutation of the caller's slice cannot change what
	// the store holds.
	kv.m[key] = append([]byte(nil), value...)
	return nil
}
func
TestCheckpoint
(
t
*
testing
.
T
)
{
kv
:=
&
myKV
{}
s
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
nil
)
assert
.
Nil
(
t
,
err
)
err
=
s
.
checkpoint
()
assert
.
Nil
(
t
,
err
)
_
,
err
=
LoadCheckpoint
(
kv
,
0
)
assert
.
Nil
(
t
,
err
)
}
// float32ToByte returns the little-endian binary encoding of f as written
// by binary.Write. If the write fails, the error is printed to standard
// output and whatever the buffer holds is returned.
func float32ToByte(f float32) []byte {
	out := new(bytes.Buffer)
	if writeErr := binary.Write(out, binary.LittleEndian, f); writeErr != nil {
		fmt.Println("binary.Write failed:", writeErr)
	}
	return out.Bytes()
}
func
TestCheckpointWithData
(
t
*
testing
.
T
)
{
kv
:=
&
myKV
{}
s
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
nil
)
assert
.
Nil
(
t
,
err
)
var
content
[]
byte
for
i
:=
0
;
i
<
50000
;
i
++
{
content
=
append
(
content
,
float32ToByte
(
float32
(
i
))
...
)
}
p1
:=
Parameter
{
Name
:
"p1"
,
ElementType
:
1
,
Content
:
content
}
err
=
s
.
InitParam
(
ParameterWithConfig
{
Param
:
p1
},
nil
)
assert
.
Nil
(
t
,
err
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
assert
.
Nil
(
t
,
err
)
var
p2
Parameter
err
=
s
.
GetParam
(
p1
.
Name
,
&
p2
)
assert
.
Nil
(
t
,
err
)
assert
.
Equal
(
t
,
p1
,
p2
)
err
=
s
.
checkpoint
()
assert
.
Nil
(
t
,
err
)
cp
,
err
:=
LoadCheckpoint
(
kv
,
0
)
assert
.
Nil
(
t
,
err
)
s1
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
cp
)
assert
.
Nil
(
t
,
err
)
var
p3
Parameter
err
=
s1
.
GetParam
(
p1
.
Name
,
&
p3
)
assert
.
Nil
(
t
,
err
)
assert
.
Equal
(
t
,
p1
,
p3
)
}
go/pserver/service_test.go
浏览文件 @
b1cbdf03
...
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
wg
.
Wait
()
}
}
// TestCheckpointSpeed is a placeholder for a checkpoint speed test. It is
// reported as skipped rather than passed, so the missing coverage stays
// visible in test output.
func TestCheckpointSpeed(t *testing.T) {
	// TODO(zhihong): test speed
	t.Skip("checkpoint speed test is not implemented yet")
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录