diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index f9cd8f87e8f2e715c87834ee08482be0f511f681..bec5775d540729000ab2dd3002600f0a92619d70 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -32,7 +32,7 @@ import ( func main() { port := flag.Int("port", 0, "port of the pserver") - index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") + index := flag.Int("index", -1, "index of the pserver, set to -1 if use etcd for auto pserver index registry") etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") dialTimeout := flag.Duration("dial-timeout", 5*time.Second, "dial timeout") @@ -60,12 +60,12 @@ func main() { idx, err = e.Register(*port) candy.Must(err) - cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e) + cp, err = pserver.LoadCheckpoint(e, idx) if err != nil { if err == pserver.ErrCheckpointNotFound { log.Infof("Could not find the pserver checkpoint.") } else { - log.Errorf("Fetch checkpoint failed, %s", err) + panic(err) } } } diff --git a/go/glide.lock b/go/glide.lock index 1f16abdf66422abcd0ab7987cab3499d02cf1b9c..be1fb24d772a6524cb798c6169c23ff03e9fed7b 100644 --- a/go/glide.lock +++ b/go/glide.lock @@ -1,5 +1,5 @@ -hash: 2a1c0eca5c07a130e3d224f9821f96cfa37a39bf6bce141c855bbc57ef569f1c -updated: 2017-07-29T07:34:48.722757905+08:00 +hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582 +updated: 2017-08-03T21:46:51.744995189Z imports: - name: github.com/beorn7/perks version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 @@ -145,6 +145,8 @@ imports: version: a1dba9ce8baed984a2495b658c82687f8157b98f subpackages: - xfs +- name: github.com/satori/go.uuid + version: 879c5887cd475cd7864858769793b2ceb0d44feb - name: github.com/sirupsen/logrus version: a3f95b5c423586578a4e099b11a46c2479628cac - name: github.com/topicai/candy diff --git a/go/glide.yaml b/go/glide.yaml index bc23fa6ebf2c3db61e2d63e5f7e7ddcb595dfed0..a90e71b615de92d64c79823e2a04c46001963932 100644 --- a/go/glide.yaml +++ b/go/glide.yaml @@ -14,11 +14,13 @@ import: version: ^1.0.0 - package: github.com/topicai/candy - package: golang.org/x/crypto - vcs: git repo: https://github.com/golang/crypto.git -- package: golang.org/x/sys vcs: git +- package: golang.org/x/sys repo: https://github.com/golang/sys.git -- package: golang.org/x/text vcs: git +- package: golang.org/x/text repo: https://github.com/golang/text.git + vcs: git +- package: github.com/satori/go.uuid + version: v1.1.0 diff --git a/go/pserver/client/client_test.go b/go/pserver/client/client_test.go index b630d434dca283df67f5b850b35057870fe27529..1243ebd6836550d58144b5033e2755ae8594e948 100644 --- a/go/pserver/client/client_test.go +++ b/go/pserver/client/client_test.go @@ -59,7 +59,7 @@ func initClient() [numPserver]int { go func(l net.Listener) { var cp pserver.Checkpoint - s, err := pserver.NewService(0, 1, "", nil, cp) + s, err := pserver.NewService(0, time.Hour, "", nil, cp) if err != nil { panic(err) } diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 4fb26307667295ab825d07be6c3d1d4b33f6eb8b..41f0640fc09a3265c0e11c06255c7ee834983203 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -206,6 +206,7 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { if err != nil { return []byte{}, err } + kvs := resp.Kvs if len(kvs) == 0 { return []byte{}, nil @@ -215,9 +216,14 @@ func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { } // PutKey put into etcd with value by key specified -func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { +func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) - _, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease())) + var err error + if withLease { + _, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease())) + } else { + _, err = e.client.Put(ctx, key, string(value)) + } cancel() return err } diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 709160d45d98b6cf6d60f52ceb3fb33e0a0bd17d..ae7359073494bd9cb6b70b12af4daca064179556 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -32,6 +32,7 @@ type optimizer struct { opt *C.struct_paddle_optimizer elementType ElementType contentLen int + config []byte } func cArrayToSlice(p unsafe.Pointer, len int) []byte { @@ -70,6 +71,7 @@ func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer cstate = unsafe.Pointer(&s[0]) } + o.config = c o.opt = C.paddle_create_optimizer((*C.uchar)(&c[0]), C.int(len(c)), C.paddle_element_type(p.ElementType), cbuffer, C.int(paramBufferSize), (*C.char)(cstate), C.int(len(s))) return o diff --git a/go/pserver/service.go b/go/pserver/service.go index 7d297c46d03bf78d18ca9830a318968397119d3e..25751540a9a2dff043c14e0912bfab1aaa938ab4 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -25,11 +25,13 @@ import ( "fmt" "io/ioutil" "os" - "path/filepath" + "path" "strconv" "sync" "time" + uuid "github.com/satori/go.uuid" + log "github.com/sirupsen/logrus" ) @@ -42,9 +44,9 @@ var ErrCheckpointNotFound = errors.New("checkpoint not found") // RPC error message. const ( - AlreadyInitialized = "pserver already initialized" - Uninitialized = "pserver not fully initialized" - CheckpointMD5Failed = "checkpoint file MD5 validation failed" + AlreadyInitialized = "pserver already initialized" + Uninitialized = "pserver not fully initialized" + WrongChecksum = "checkpoint file checksum validation failed" ) // Supported element types. @@ -73,11 +75,12 @@ type ParameterWithConfig struct { // checkpointMeta saves checkpoint metadata type checkpointMeta struct { UUID string `json:"uuid"` + Path string `json:"path"` MD5 string `json:"md5"` Timestamp int64 `json:"timestamp"` } -// Checkpoint is the pserver shard persist in file +// Checkpoint is the pserver shard persist in file. type Checkpoint []parameterCheckpoint // Gradient is the gradient of the parameter. @@ -90,50 +93,58 @@ type Service struct { checkpointInterval time.Duration checkpointPath string client *EtcdClient - mu sync.Mutex - optMap map[string]*optimizer + + mu sync.Mutex + optMap map[string]*optimizer } -// parameterCheckpoint saves parameter checkpoint +// parameterCheckpoint saves parameter checkpoint. type parameterCheckpoint struct { ParameterWithConfig State []byte } -// NewCheckpointFromFile loads parameters and state from checkpoint file -func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) { - v, err := e.GetKey(PsPath+string(idx), 3*time.Second) +func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) { + v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second) if err != nil { - return nil, err + return } if len(v) == 0 { - return nil, ErrCheckpointNotFound + err = ErrCheckpointNotFound + return } - var cpMeta checkpointMeta - if err = json.Unmarshal(v, &cpMeta); err != nil { - return nil, err + if err = json.Unmarshal(v, &meta); err != nil { + return } - fn := filepath.Join(cpPath, cpMeta.UUID) - if _, err = os.Stat(fn); os.IsNotExist(err) { + return +} + +// LoadCheckpoint loads checkpoint from file. +func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) { + cpMeta, err := loadMeta(e, idx) + if err != nil { return nil, err } - content, err := ioutil.ReadFile(fn) + + content, err := ioutil.ReadFile(cpMeta.Path) if err != nil { return nil, err } + // TODO(helin): change MD5 to CRC since CRC is better for file + // checksum in our use case (emphasize speed over security). h := md5.New() md5 := hex.EncodeToString(h.Sum(content)) if md5 != cpMeta.MD5 { - return nil, errors.New(CheckpointMD5Failed) + return nil, errors.New(WrongChecksum) } dec := gob.NewDecoder(bytes.NewReader(content)) - cp := Checkpoint{} - if err = dec.Decode(cp); err != nil { + var cp Checkpoint + if err = dec.Decode(&cp); err != nil { return nil, err } return cp, nil @@ -193,6 +204,15 @@ func (s *Service) FinishInitParams(_ int, _ *int) error { } close(s.initialized) + go func() { + t := time.Tick(s.checkpointInterval) + for range t { + err := s.checkpoint() + if err != nil { + log.Errorln(err) + } + } + }() return nil } @@ -240,23 +260,36 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { return nil } -// pserver save checkpoint -func (s *Service) doCheckpoint() (err error) { - <-s.initialized - s.mu.Lock() - defer s.mu.Unlock() +func traceTime(start time.Time, name string) { + elapsed := time.Since(start) + log.Infof("%s took %v", name, elapsed) +} + +// checkpoint saves checkpoint to disk. +// +// checkpoint should be only called after the parameters are +// initialized. +func (s *Service) checkpoint() (err error) { + log.Infoln("Begin save checkpoint.") + defer traceTime(time.Now(), "save checkpoint") + s.mu.Lock() cp := make([]parameterCheckpoint, len(s.optMap)) index := 0 + // TODO(helin): write checkpoint incrementally to reduce memory + // footprint during checkpoint. for name, opt := range s.optMap { var pc parameterCheckpoint pc.Param.Name = name pc.Param.ElementType = opt.elementType pc.Param.Content = opt.GetWeights() + pc.Config = opt.config pc.State = opt.GetStates() cp[index] = pc index++ } + s.mu.Unlock() + var buf bytes.Buffer encoder := gob.NewEncoder(&buf) err = encoder.Encode(cp) @@ -264,32 +297,9 @@ func (s *Service) doCheckpoint() (err error) { return } - cpMeta := checkpointMeta{} - cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) - cpMeta.Timestamp = time.Now().UnixNano() - h := md5.New() - cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes())) - - cpMetajson, err := json.Marshal(cpMeta) - if err != nil { - return - } - - err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second) - if err != nil { - return - } - if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) { - log.Info("checkpoint does not exists.") - } else { - err = os.Remove(cpMeta.UUID) - if err != nil { - log.Infof("Removing checkpoint %s failed", cpMeta.UUID) - } else { - log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) - } - } - f, err := os.Create(cpMeta.UUID) + id := uuid.NewV4().String() + p := path.Join(s.checkpointPath, id) + f, err := os.Create(p) if err != nil { return } @@ -317,5 +327,43 @@ func (s *Service) doCheckpoint() (err error) { return } + oldMeta, err := loadMeta(s.client, s.idx) + if err == ErrCheckpointNotFound { + log.Infoln("Do not have existing checkpoint.") + err = nil + } + + if err != nil { + return + } + + h := md5.New() + md5 := hex.EncodeToString(h.Sum(buf.Bytes())) + cpMeta := checkpointMeta{ + UUID: id, + Timestamp: time.Now().UnixNano(), + MD5: md5, + Path: p, + } + + json, err := json.Marshal(cpMeta) + if err != nil { + return + } + + err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false) + if err != nil { + return + } + + if oldMeta.Path != "" { + rmErr := os.Remove(oldMeta.Path) + if rmErr != nil { + // log error, but still treat checkpoint as + // successful. + log.Errorln(rmErr) + } + } + return } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 988f3b5acb82a95aeb54af2b8b0e4d39a458291a..be648cd1e83e4f7790edac5842db432fb4870072 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -30,7 +30,7 @@ const ( func TestServiceFull(t *testing.T) { var cp pserver.Checkpoint - s, err := pserver.NewService(0, 1, "", nil, cp) + s, err := pserver.NewService(0, time.Hour, "", nil, cp) if err != nil { t.Error(err) } @@ -102,7 +102,7 @@ func TestServiceFull(t *testing.T) { func TestMultipleInit(t *testing.T) { var cp pserver.Checkpoint - s, err := pserver.NewService(0, 1, "", nil, cp) + s, err := pserver.NewService(0, time.Hour, "", nil, cp) if err != nil { t.Fatal(err) } @@ -119,7 +119,7 @@ func TestMultipleInit(t *testing.T) { func TestUninitialized(t *testing.T) { var cp pserver.Checkpoint - s, err := pserver.NewService(0, 1, "", nil, cp) + s, err := pserver.NewService(0, time.Hour, "", nil, cp) err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.Fatal(err) @@ -128,7 +128,7 @@ func TestUninitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) { var cp pserver.Checkpoint - s, err := pserver.NewService(0, 1, "", nil, cp) + s, err := pserver.NewService(0, time.Hour, "", nil, cp) if err != nil { t.Error(err) }