提交 d1cda903 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #3222 from helinwang/checkpoint

Fix pserver save / load checkpoint
...@@ -32,7 +32,7 @@ import ( ...@@ -32,7 +32,7 @@ import (
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout") dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout")
...@@ -60,12 +60,12 @@ func main() { ...@@ -60,12 +60,12 @@ func main() {
idx, err = e.Register(*port) idx, err = e.Register(*port)
candy.Must(err) candy.Must(err)
cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e) cp, err = pserver.LoadCheckpoint(e, idx)
if err != nil { if err != nil {
if err == pserver.ErrCheckpointNotFound { if err == pserver.ErrCheckpointNotFound {
log.Infof("Could not find the pserver checkpoint.") log.Infof("Could not find the pserver checkpoint.")
} else { } else {
log.Errorf("Fetch checkpoint failed, %s", err) panic(err)
} }
} }
} }
......
hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-07-29T07:34:48.722757905+08:00 updated: 2017-08-03T21:46:51.744995189Z
imports: imports:
- name: github.com/beorn7/perks - name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...@@ -145,6 +145,8 @@ imports: ...@@ -145,6 +145,8 @@ imports:
version: a1dba9ce8baed984a2495b658c82687f8157b98f version: a1dba9ce8baed984a2495b658c82687f8157b98f
subpackages: subpackages:
- xfs - xfs
- name: github.com/satori/go.uuid
version: 879c5887cd475cd7864858769793b2ceb0d44feb
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: a3f95b5c423586578a4e099b11a46c2479628cac version: a3f95b5c423586578a4e099b11a46c2479628cac
- name: github.com/topicai/candy - name: github.com/topicai/candy
......
...@@ -14,11 +14,13 @@ import: ...@@ -14,11 +14,13 @@ import:
version: ^1.0.0 version: ^1.0.0
- package: github.com/topicai/candy - package: github.com/topicai/candy
- package: golang.org/x/crypto - package: golang.org/x/crypto
vcs: git
repo: https://github.com/golang/crypto.git repo: https://github.com/golang/crypto.git
- package: golang.org/x/sys
vcs: git vcs: git
- package: golang.org/x/sys
repo: https://github.com/golang/sys.git repo: https://github.com/golang/sys.git
- package: golang.org/x/text
vcs: git vcs: git
- package: golang.org/x/text
repo: https://github.com/golang/text.git repo: https://github.com/golang/text.git
vcs: git
- package: github.com/satori/go.uuid
version: v1.1.0
...@@ -59,7 +59,7 @@ func initClient() [numPserver]int { ...@@ -59,7 +59,7 @@ func initClient() [numPserver]int {
go func(l net.Listener) { go func(l net.Listener) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -206,6 +206,7 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ...@@ -206,6 +206,7 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
if err != nil { if err != nil {
return []byte{}, err return []byte{}, err
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
return []byte{}, nil return []byte{}, nil
...@@ -215,9 +216,14 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ...@@ -215,9 +216,14 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
} }
// PutKey put into etcd with value by key specified // PutKey put into etcd with value by key specified
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
ctx, cancel := context.WithTimeout(context.Background(), timeout) ctx, cancel := context.WithTimeout(context.Background(), timeout)
_, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease())) var err error
if withLease {
_, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
} else {
_, err = e.client.Put(ctx, key, string(value))
}
cancel() cancel()
return err return err
} }
......
...@@ -32,6 +32,7 @@ type optimizer struct { ...@@ -32,6 +32,7 @@ type optimizer struct {
opt *C.struct_paddle_optimizer opt *C.struct_paddle_optimizer
elementType ElementType elementType ElementType
contentLen int contentLen int
config []byte
} }
func cArrayToSlice(p unsafe.Pointer, len int) []byte { func cArrayToSlice(p unsafe.Pointer, len int) []byte {
...@@ -70,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer ...@@ -70,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer
cstate = unsafe.Pointer(&s[0]) cstate = unsafe.Pointer(&s[0])
} }
o.config = c
o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)),
C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s))) C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s)))
return o return o
......
...@@ -25,11 +25,13 @@ import ( ...@@ -25,11 +25,13 @@ import (
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"path/filepath" "path"
"strconv" "strconv"
"sync" "sync"
"time" "time"
uuid "github.com/satori/go.uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -42,9 +44,9 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found") ...@@ -42,9 +44,9 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found")
// RPC error message. // RPC error message.
const ( const (
AlreadyInitialized = "pserver already initialized" AlreadyInitialized = "pserver already initialized"
Uninitialized = "pserver not fully initialized" Uninitialized = "pserver not fully initialized"
CheckpointMD5Failed = "checkpoint file MD5 validation failed" WrongChecksum = "checkpoint file checksum validation failed"
) )
// Supported element types. // Supported element types.
...@@ -73,11 +75,12 @@ type ParameterWithConfig struct { ...@@ -73,11 +75,12 @@ type ParameterWithConfig struct {
// checkpointMeta saves checkpoint metadata // checkpointMeta saves checkpoint metadata
type checkpointMeta struct { type checkpointMeta struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
Path string `json:"path"`
MD5 string `json:"md5"` MD5 string `json:"md5"`
Timestamp int64 `json:"timestamp"` Timestamp int64 `json:"timestamp"`
} }
// Checkpoint is the pserver shard persist in file // Checkpoint is the pserver shard persist in file.
type Checkpoint []parameterCheckpoint type Checkpoint []parameterCheckpoint
// Gradient is the gradient of the parameter. // Gradient is the gradient of the parameter.
...@@ -90,50 +93,58 @@ type Service struct { ...@@ -90,50 +93,58 @@ type Service struct {
checkpointInterval time.Duration checkpointInterval time.Duration
checkpointPath string checkpointPath string
client *EtcdClient client *EtcdClient
mu sync.Mutex
optMap map[string]*optimizer mu sync.Mutex
optMap map[string]*optimizer
} }
// parameterCheckpoint saves parameter checkpoint // parameterCheckpoint saves parameter checkpoint.
type parameterCheckpoint struct { type parameterCheckpoint struct {
ParameterWithConfig ParameterWithConfig
State []byte State []byte
} }
// NewCheckpointFromFile loads parameters and state from checkpoint file func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) { v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
if err != nil { if err != nil {
return nil, err return
} }
if len(v) == 0 { if len(v) == 0 {
return nil, ErrCheckpointNotFound err = ErrCheckpointNotFound
return
} }
var cpMeta checkpointMeta if err = json.Unmarshal(v, &meta); err != nil {
if err = json.Unmarshal(v, &cpMeta); err != nil { return
return nil, err
} }
fn := filepath.Join(cpPath, cpMeta.UUID) return
if _, err = os.Stat(fn); os.IsNotExist(err) { }
// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
cpMeta, err := loadMeta(e, idx)
if err != nil {
return nil, err return nil, err
} }
content, err := ioutil.ReadFile(fn)
content, err := ioutil.ReadFile(cpMeta.Path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// TODO(helin): change MD5 to CRC since CRC is better for file
// checksum in our use case (emphasize speed over security).
h := md5.New() h := md5.New()
md5 := hex.EncodeToString(h.Sum(content)) md5 := hex.EncodeToString(h.Sum(content))
if md5 != cpMeta.MD5 { if md5 != cpMeta.MD5 {
return nil, errors.New(CheckpointMD5Failed) return nil, errors.New(WrongChecksum)
} }
dec := gob.NewDecoder(bytes.NewReader(content)) dec := gob.NewDecoder(bytes.NewReader(content))
cp := Checkpoint{} var cp Checkpoint
if err = dec.Decode(cp); err != nil { if err = dec.Decode(&cp); err != nil {
return nil, err return nil, err
} }
return cp, nil return cp, nil
...@@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error { ...@@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error {
} }
close(s.initialized) close(s.initialized)
go func() {
t := time.Tick(s.checkpointInterval)
for range t {
err := s.checkpoint()
if err != nil {
log.Errorln(err)
}
}
}()
return nil return nil
} }
...@@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { ...@@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error {
return nil return nil
} }
// pserver save checkpoint func traceTime(start time.Time, name string) {
func (s *Service) doCheckpoint() (err error) { elapsed := time.Since(start)
<-s.initialized log.Infof("%s took %v", name, elapsed)
s.mu.Lock() }
defer s.mu.Unlock()
// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
log.Infoln("Begin save checkpoint.")
defer traceTime(time.Now(), "save checkpoint")
s.mu.Lock()
cp := make([]parameterCheckpoint, len(s.optMap)) cp := make([]parameterCheckpoint, len(s.optMap))
index := 0 index := 0
// TODO(helin): write checkpoint incrementally to reduce memory
// footprint during checkpoint.
for name, opt := range s.optMap { for name, opt := range s.optMap {
var pc parameterCheckpoint var pc parameterCheckpoint
pc.Param.Name = name pc.Param.Name = name
pc.Param.ElementType = opt.elementType pc.Param.ElementType = opt.elementType
pc.Param.Content = opt.GetWeights() pc.Param.Content = opt.GetWeights()
pc.Config = opt.config
pc.State = opt.GetStates() pc.State = opt.GetStates()
cp[index] = pc cp[index] = pc
index++ index++
} }
s.mu.Unlock()
var buf bytes.Buffer var buf bytes.Buffer
encoder := gob.NewEncoder(&buf) encoder := gob.NewEncoder(&buf)
err = encoder.Encode(cp) err = encoder.Encode(cp)
...@@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) { ...@@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) {
return return
} }
cpMeta := checkpointMeta{} id := uuid.NewV4().String()
cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) p := path.Join(s.checkpointPath, id)
cpMeta.Timestamp = time.Now().UnixNano() f, err := os.Create(p)
h := md5.New()
cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
cpMetajson, err := json.Marshal(cpMeta)
if err != nil {
return
}
err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
if err != nil {
return
}
if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
log.Info("checkpoint does not exists.")
} else {
err = os.Remove(cpMeta.UUID)
if err != nil {
log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
} else {
log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
}
}
f, err := os.Create(cpMeta.UUID)
if err != nil { if err != nil {
return return
} }
...@@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) { ...@@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) {
return return
} }
oldMeta, err := loadMeta(s.client, s.idx)
if err == ErrCheckpointNotFound {
log.Infoln("Do not have existing checkpoint.")
err = nil
}
if err != nil {
return
}
h := md5.New()
md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
cpMeta := checkpointMeta{
UUID: id,
Timestamp: time.Now().UnixNano(),
MD5: md5,
Path: p,
}
json, err := json.Marshal(cpMeta)
if err != nil {
return
}
err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false)
if err != nil {
return
}
if oldMeta.Path != "" {
rmErr := os.Remove(oldMeta.Path)
if rmErr != nil {
// log error, but still treat checkpoint as
// successful.
log.Errorln(rmErr)
}
}
return return
} }
...@@ -30,7 +30,7 @@ const ( ...@@ -30,7 +30,7 @@ const (
func TestServiceFull(t *testing.T) { func TestServiceFull(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
...@@ -102,7 +102,7 @@ func TestServiceFull(t *testing.T) { ...@@ -102,7 +102,7 @@ func TestServiceFull(t *testing.T) {
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
...@@ -119,7 +119,7 @@ func TestMultipleInit(t *testing.T) { ...@@ -119,7 +119,7 @@ func TestMultipleInit(t *testing.T) {
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
err = s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.Fatal(err) t.Fatal(err)
...@@ -128,7 +128,7 @@ func TestUninitialized(t *testing.T) { ...@@ -128,7 +128,7 @@ func TestUninitialized(t *testing.T) {
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
var cp pserver.Checkpoint var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp) s, err := pserver.NewService(0, time.Hour, "", nil, cp)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册