From f1330e216a1b8130bb578b69ff2d6a67357cdd1b Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Mon, 3 Jul 2017 23:20:39 +0800 Subject: [PATCH] "saving checkpoint" --- go/pserver/service.go | 79 +++++++++++++++++++++++++++++++++++--- go/pserver/service_test.go | 6 +++ 2 files changed, 80 insertions(+), 5 deletions(-) diff --git a/go/pserver/service.go b/go/pserver/service.go index a5ff8629033..a4cf3e47507 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,9 +1,19 @@ package pserver import ( + "bufio" + "bytes" + "crypto/md5" + "encoding/gob" + "encoding/hex" "errors" "fmt" + "os" + "strconv" "sync" + "time" + + log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -14,6 +24,10 @@ const ( Uninitialized = "pserver not fully initialized" ) +const ( + checkpoint_path = "/checkpoints/" +) + // Supported element types const ( Int32 ElementType = iota @@ -53,6 +67,24 @@ type Service struct { optMap map[string]*optimizer } +type Checkpoint struct { + uuid string + md5sum string + timestamp string +} + +//serialize ParameterWithConfig to byte stream +func GetBytes(content ...interface{}) ([]byte, error) { + + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(content) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + // NewService creates a new service, will bypass etcd registration if no // endpoints specified. func NewService(idx int) (*Service, error) { @@ -143,13 +175,50 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { // Save tells the parameter server to save parameters. func (s *Service) Save(path string, dummy *int) error { + //FIXME: checkpoint is only used by pserver + // and has a constant path of */checkpoints/{pserver_idx}* <-s.initialized - for opt, ok := range s.optMap { - if ok != nil { - return fmt.Errorf("parameter optimizerMap error: ", ok) + s.mu.Lock() + defer s.mu.Unlock() + var paramWithConfig ParameterWithConfig + for name, opt := range s.optMap { + paramWithConfig.Param.Name = name + paramWithConfig.Param.ElementType = opt.elementType + paramWithConfig.Param.Content = opt.GetWeights() + paramWithConfig.State = opt.GetStates() + content, err := GetBytes(paramWithConfig) + if err != nil { + log.Errorln(err) + } + ck := Checkpoint{} + h := md5.New() + ck.md5sum = hex.EncodeToString(h.Sum(content)) + ck.timestamp = time.Now().String() + ck.uuid = checkpoint_path + strconv.Itoa(s.idx) + ckbytes, err := GetBytes(ck) + if err != nil { + log.Errorln(err) + } + // TODO: according design doc, need to save uuid to etcd in json format + // {\"uuid\": [UUID], \"md5\", \"MD5 sum\", \"timestamp\": xxxx} + log.Infof("parameter checkpoint %s", ckbytes) + + if _, err = os.Stat(ck.uuid); os.IsNotExist(err) { + log.Info("checkpoint not exists.") + } else { + err = os.Remove(ck.uuid) + log.Infof("remove %s", ck.uuid) + } + f, err := os.Create(ck.uuid) + defer f.Close() + if err != nil { + log.Errorln(err) + } + writer := bufio.NewWriter(f) + _, err = writer.Write(content) + if err != nil { + log.Errorln(err) } - state := opt.GetStates() - weights := opt.GetWeights() } return nil } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index f86619447c2..28956e4d851 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -79,6 +79,8 @@ func TestServiceFull(t *testing.T) { if !reflect.DeepEqual(param1, p) { t.FailNow() } + var dummy int + s.Save("", &dummy) } func TestMultipleInit(t *testing.T) { @@ -166,3 +168,7 @@ func TestBlockUntilInitialized(t *testing.T) { wg.Wait() } + +func TestCheckpointSpeed(t *testing.T) { + //TODO: test speed +} -- GitLab