Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
b1cbdf03
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b1cbdf03
编写于
10月 26, 2017
作者:
H
helinwang
提交者:
GitHub
10月 26, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #5102 from helinwang/checkpoint
Fix pserver checkpoint
上级
66476fc7
00e2dcf3
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
126 addition
and
34 deletion
+126
-34
go/cmd/pserver/pserver.go
go/cmd/pserver/pserver.go
+2
-2
go/pserver/optimizer.go
go/pserver/optimizer.go
+7
-1
go/pserver/service.go
go/pserver/service.go
+31
-27
go/pserver/service_internal_test.go
go/pserver/service_internal_test.go
+86
-0
go/pserver/service_test.go
go/pserver/service_test.go
+0
-4
未找到文件。
go/cmd/pserver/pserver.go
浏览文件 @
b1cbdf03
...
...
@@ -67,7 +67,7 @@ func main() {
cp
,
err
=
pserver
.
LoadCheckpoint
(
e
,
idx
)
if
err
!=
nil
{
if
err
==
pserver
.
ErrCheckpointNotFound
{
log
.
Info
(
"
Could not find the pserver checkpoint."
)
log
.
Info
(
"
load checkpoint error"
,
"error"
,
err
)
}
else
{
panic
(
err
)
}
...
...
@@ -99,7 +99,7 @@ func main() {
candy
.
Must
(
err
)
go
func
()
{
log
.
Info
(
"s
tart
ing pserver"
,
log
.
Ctx
{
"port"
:
*
port
})
log
.
Info
(
"s
erv
ing pserver"
,
log
.
Ctx
{
"port"
:
*
port
})
err
=
http
.
Serve
(
l
,
nil
)
candy
.
Must
(
err
)
}()
...
...
go/pserver/optimizer.go
浏览文件 @
b1cbdf03
...
...
@@ -71,9 +71,15 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate
=
unsafe
.
Pointer
(
&
s
[
0
])
}
var
cptr
(
*
C
.
uchar
)
if
len
(
c
)
>
0
{
cptr
=
(
*
C
.
uchar
)(
&
c
[
0
])
}
else
{
log
.
Error
(
"empty config"
,
"param name"
,
paramWithConfigs
.
Param
.
Name
)
}
o
.
config
=
c
o
.
opt
=
C
.
paddle_create_optimizer
(
(
*
C
.
uchar
)(
&
c
[
0
])
,
cptr
,
C
.
int
(
len
(
c
)),
C
.
paddle_element_type
(
p
.
ElementType
),
cbuffer
,
...
...
go/pserver/service.go
浏览文件 @
b1cbdf03
...
...
@@ -17,12 +17,11 @@ package pserver
import
(
"bufio"
"bytes"
"crypto/md5"
"encoding/gob"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"hash/crc32"
"io/ioutil"
"os"
"path"
...
...
@@ -40,7 +39,7 @@ type ElementType int
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found"
)
var
ErrCheckpointNotFound
=
errors
.
New
(
"checkpoint not found
in etcd
"
)
// RPC error message.
const
(
...
...
@@ -76,7 +75,7 @@ type ParameterWithConfig struct {
type
checkpointMeta
struct
{
UUID
string
`json:"uuid"`
Path
string
`json:"path"`
MD5
string
`json:"md5
"`
CRC32
uint32
`json:"crc32
"`
Timestamp
int64
`json:"timestamp"`
}
...
...
@@ -92,7 +91,7 @@ type Service struct {
idx
int
checkpointInterval
time
.
Duration
checkpointPath
string
client
*
EtcdClient
client
KVStore
mu
sync
.
Mutex
optMap
map
[
string
]
*
optimizer
...
...
@@ -104,7 +103,12 @@ type parameterCheckpoint struct {
State
[]
byte
}
func
loadMeta
(
e
*
EtcdClient
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
type
KVStore
interface
{
GetKey
(
key
string
,
timeout
time
.
Duration
)
([]
byte
,
error
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
,
withLease
bool
)
error
}
func
loadMeta
(
e
KVStore
,
idx
int
)
(
meta
checkpointMeta
,
err
error
)
{
v
,
err
:=
e
.
GetKey
(
PsCheckpoint
+
strconv
.
Itoa
(
idx
),
3
*
time
.
Second
)
if
err
!=
nil
{
return
...
...
@@ -123,7 +127,7 @@ func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
}
// LoadCheckpoint loads checkpoint from file.
func
LoadCheckpoint
(
e
*
EtcdClient
,
idx
int
)
(
Checkpoint
,
error
)
{
func
LoadCheckpoint
(
e
KVStore
,
idx
int
)
(
Checkpoint
,
error
)
{
log
.
Info
(
"Loading checkpoint"
,
"pserver index"
,
idx
)
defer
traceTime
(
time
.
Now
(),
"load checkpoint"
)
...
...
@@ -137,11 +141,8 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
return
nil
,
err
}
// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
content
))
if
md5
!=
cpMeta
.
MD5
{
crc32
:=
crc32
.
ChecksumIEEE
(
content
)
if
crc32
!=
cpMeta
.
CRC32
{
return
nil
,
errors
.
New
(
WrongChecksum
)
}
...
...
@@ -150,12 +151,13 @@ func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
if
err
=
dec
.
Decode
(
&
cp
);
err
!=
nil
{
return
nil
,
err
}
return
cp
,
nil
}
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
*
EtcdClient
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
func
NewService
(
idx
int
,
interval
time
.
Duration
,
path
string
,
client
KVStore
,
cp
Checkpoint
)
(
*
Service
,
error
)
{
s
:=
&
Service
{
idx
:
idx
,
checkpointInterval
:
interval
,
...
...
@@ -173,6 +175,7 @@ func NewService(idx int, interval time.Duration, path string, client *EtcdClient
}
s
.
optMap
[
p
.
Param
.
Name
]
=
newOptimizer
(
p
,
item
.
State
)
}
close
(
s
.
initialized
)
}
return
s
,
nil
}
...
...
@@ -221,7 +224,7 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
for
range
t
{
err
:=
s
.
checkpoint
()
if
err
!=
nil
{
log
.
Error
(
"
finish init params
error"
,
log
.
Ctx
{
"error"
:
err
})
log
.
Error
(
"
checkpoint
error"
,
log
.
Ctx
{
"error"
:
err
})
}
}
}()
...
...
@@ -274,6 +277,7 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
parameter
.
Name
=
name
parameter
.
ElementType
=
opt
.
elementType
parameter
.
Content
=
opt
.
GetWeights
()
log
.
Info
(
"sending parameter to the trainer"
,
"name"
,
parameter
.
Name
,
"size"
,
len
(
parameter
.
Content
),
"type"
,
parameter
.
ElementType
)
return
nil
}
...
...
@@ -354,20 +358,29 @@ func (s *Service) checkpoint() (err error) {
oldMeta
,
err
:=
loadMeta
(
s
.
client
,
s
.
idx
)
if
err
==
ErrCheckpointNotFound
{
log
.
Info
(
"
Do not have existing checkpoint.
"
)
log
.
Info
(
"
old meta not found, skip removing old meta
"
)
err
=
nil
}
else
if
err
==
nil
{
log
.
Info
(
"removing old meta"
)
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Error
(
"remove old meta file error"
,
log
.
Ctx
{
"error"
:
rmErr
})
}
}
}
if
err
!=
nil
{
return
}
h
:=
md5
.
New
()
md5
:=
hex
.
EncodeToString
(
h
.
Sum
(
buf
.
Bytes
()))
crc32
:=
crc32
.
ChecksumIEEE
(
buf
.
Bytes
())
cpMeta
:=
checkpointMeta
{
UUID
:
id
,
Timestamp
:
time
.
Now
()
.
UnixNano
(),
MD5
:
md5
,
CRC32
:
crc32
,
Path
:
p
,
}
...
...
@@ -381,14 +394,5 @@ func (s *Service) checkpoint() (err error) {
return
}
if
oldMeta
.
Path
!=
""
{
rmErr
:=
os
.
Remove
(
oldMeta
.
Path
)
if
rmErr
!=
nil
{
// log error, but still treat checkpoint as
// successful.
log
.
Error
(
"remove old meta file error"
,
log
.
Ctx
{
"error"
:
rmErr
})
}
}
return
}
go/pserver/service_internal_test.go
0 → 100644
浏览文件 @
b1cbdf03
package
pserver
import
(
"bytes"
"encoding/binary"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
const
testDir
=
"./test_data"
type
myKV
struct
{
m
map
[
string
][]
byte
}
func
(
m
*
myKV
)
GetKey
(
key
string
,
timeout
time
.
Duration
)
([]
byte
,
error
)
{
if
m
.
m
==
nil
{
m
.
m
=
make
(
map
[
string
][]
byte
)
}
return
m
.
m
[
key
],
nil
}
func
(
m
*
myKV
)
PutKey
(
key
string
,
value
[]
byte
,
timeout
time
.
Duration
,
withLease
bool
)
error
{
if
m
.
m
==
nil
{
m
.
m
=
make
(
map
[
string
][]
byte
)
}
m
.
m
[
key
]
=
value
return
nil
}
func
TestCheckpoint
(
t
*
testing
.
T
)
{
kv
:=
&
myKV
{}
s
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
nil
)
assert
.
Nil
(
t
,
err
)
err
=
s
.
checkpoint
()
assert
.
Nil
(
t
,
err
)
_
,
err
=
LoadCheckpoint
(
kv
,
0
)
assert
.
Nil
(
t
,
err
)
}
func
float32ToByte
(
f
float32
)
[]
byte
{
var
buf
bytes
.
Buffer
err
:=
binary
.
Write
(
&
buf
,
binary
.
LittleEndian
,
f
)
if
err
!=
nil
{
fmt
.
Println
(
"binary.Write failed:"
,
err
)
}
return
buf
.
Bytes
()
}
func
TestCheckpointWithData
(
t
*
testing
.
T
)
{
kv
:=
&
myKV
{}
s
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
nil
)
assert
.
Nil
(
t
,
err
)
var
content
[]
byte
for
i
:=
0
;
i
<
50000
;
i
++
{
content
=
append
(
content
,
float32ToByte
(
float32
(
i
))
...
)
}
p1
:=
Parameter
{
Name
:
"p1"
,
ElementType
:
1
,
Content
:
content
}
err
=
s
.
InitParam
(
ParameterWithConfig
{
Param
:
p1
},
nil
)
assert
.
Nil
(
t
,
err
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
assert
.
Nil
(
t
,
err
)
var
p2
Parameter
err
=
s
.
GetParam
(
p1
.
Name
,
&
p2
)
assert
.
Nil
(
t
,
err
)
assert
.
Equal
(
t
,
p1
,
p2
)
err
=
s
.
checkpoint
()
assert
.
Nil
(
t
,
err
)
cp
,
err
:=
LoadCheckpoint
(
kv
,
0
)
assert
.
Nil
(
t
,
err
)
s1
,
err
:=
NewService
(
0
,
time
.
Hour
,
testDir
,
kv
,
cp
)
assert
.
Nil
(
t
,
err
)
var
p3
Parameter
err
=
s1
.
GetParam
(
p1
.
Name
,
&
p3
)
assert
.
Nil
(
t
,
err
)
assert
.
Equal
(
t
,
p1
,
p3
)
}
go/pserver/service_test.go
浏览文件 @
b1cbdf03
...
...
@@ -178,7 +178,3 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Wait
()
}
func
TestCheckpointSpeed
(
t
*
testing
.
T
)
{
//TODO(zhihong): test speed
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录