From 40295b9ed9ede878c930c6fc9ce6719c8270db07 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Fri, 7 Jul 2017 19:56:29 +0800 Subject: [PATCH] "fix pserver saving etcd" --- go/cmd/pserver/pserver.go | 5 +- go/pserver/etcd_client.go | 13 +++ go/pserver/optimizer.go | 4 +- go/pserver/service.go | 170 +++++++++++++++++++++---------------- go/pserver/service_test.go | 5 +- 5 files changed, 116 insertions(+), 81 deletions(-) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 31ef450f032..56c1f6e1db6 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -20,6 +20,8 @@ func main() { "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") + checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") + checkpointInterval := flag.Int("checkpoint-interval", "10", "save checkpoint per interval seconds") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() @@ -31,6 +33,7 @@ func main() { log.SetLevel(level) var idx int + var cp pserver.Checkpoint if *index >= 0 { idx = *index } else { @@ -42,7 +45,7 @@ func main() { } } - s, err := pserver.NewService(idx) + s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) if err != nil { panic(err) } diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 37b8d522c1b..20041d04d08 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -18,6 +18,8 @@ const ( PsDesired = "/ps_desired" // PsAddr is the base dir for pserver to store their addr PsPath = "/ps/" + // PsCheckpoint is the etcd path for store checkpoints information + PsCheckpoint = "/checkpoints/" ) // EtcdClient is the etcd client that the pserver uses for fault @@ -186,3 +188,14 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { return idx, nil } + +// PutKey put into etcd with value by key specified +func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { + ctx, err := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) + _, err = e.Put(ctx, key, value) + cancel() + if err != nil { + return err + } + return nil +} diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 1c84e728e0b..2d7882d1a75 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -35,12 +35,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { return (*[1 << 30]byte)(p)[:len:len] } -func newOptimizer(paramWithConfigs ParameterWithConfig) *optimizer { +func newOptimizer(paramWithConfigs ParameterWithConfig, State []byte) *optimizer { o := &optimizer{} o.elementType = paramWithConfigs.Param.ElementType p := paramWithConfigs.Param c := paramWithConfigs.Config - s := paramWithConfigs.State + s := State log.WithFields(log.Fields{ "ElementType": p.ElementType, "ParamSize": len(p.Content), diff --git a/go/pserver/service.go b/go/pserver/service.go index d1d041de59e..f27feb247d8 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -5,10 +5,11 @@ import ( "bytes" "crypto/md5" "encoding/gob" - "encoding/hex" + "encoding/json" "errors" "fmt" "os" + "path/filepath" "strconv" "sync" "time" @@ -26,10 +27,6 @@ const ( Uninitialized = "pserver not fully initialized" ) -const ( - checkpoint_path = "./checkpoints/" -) - // Supported element types const ( Int32 ElementType = iota @@ -51,49 +48,68 @@ type Parameter struct { type ParameterWithConfig struct { Param Parameter Config []byte // parameter configuration in Proto Buffer format - State []byte // parameter training state } +// Checkpoint of Parameter and State +type parameterCheckPoint struct { + ParamConfig ParameterWithConfig + State []byte +} + +// checkpoint signature +type checkpointMeta struct { + UUID string `json:"uuid"` + Md5sum string `json:"md5sum"` + Timestamp string `json:"timestamp"` +} + +// Checkpoint is the pserver shard persist in file +type Checkpoint []parameterCheckPoint + // Gradient is the gradient of the parameter. -type Gradient Parameter // Service is the RPC service for pserver. type Service struct { - initialized chan struct{} - idx int - - mu sync.Mutex - optMap map[string]*optimizer + initialized chan struct{} + idx int + checkpointInterval int + checkpointPath string + client *EtcdClient + mu sync.Mutex + optMap map[string]*optimizer } -type checkpoint struct { - Uuid string - Md5sum string - Timestamp string -} +// //serialize ParameterWithConfig to byte stream +// func GetBytes(content ...interface{}) ([]byte, error) { -//serialize ParameterWithConfig to byte stream -func GetBytes(content ...interface{}) ([]byte, error) { - - var buf bytes.Buffer - encoder := gob.NewEncoder(&buf) - err := encoder.Encode(content) - if err != nil { - return nil, err - } - return buf.Bytes(), nil -} +// var buf bytes.Buffer +// encoder := gob.NewEncoder(&buf) +// err := encoder.Encode(content) +// if err != nil { +// return nil, err +// } +// return buf.Bytes(), nil +// } // NewService creates a new service, will bypass etcd registration if no // endpoints specified. -func NewService(idx int) (*Service, error) { +func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { s := &Service{ - idx: idx, + idx: idx, + checkpointInterval: time.Second * time.Duration(seconds), + checkpointPath: path, + client: client, } s.optMap = make(map[string]*optimizer) s.initialized = make(chan struct{}) - gob.Register(ParameterWithConfig{}) - gob.Register(checkpoint{}) + + if cp != nil { + for _, item := range cp { + p := item.ParamConfig + st := item.State + s.optMap[p.Param.Name] = newOptimizer(p, st) + } + } return s, nil } @@ -174,53 +190,57 @@ func (s *Service) GetParam(name string, parameter *Parameter) error { return nil } -// Save tells the parameter server to save parameters. -func (s *Service) Save(path string, dummy *int) error { - //FIXME: checkpoint is only used by pserver - // and has a constant path of */checkpoints/{pserver_idx}* +// pserver save checkpoint +func (s *Service) doCheckpoint() error { <-s.initialized s.mu.Lock() defer s.mu.Unlock() - var paramWithConfig ParameterWithConfig + + cp := make([]parameterCheckPoint, 0, len(s.optMap)) + index := 0 for name, opt := range s.optMap { - paramWithConfig.Param.Name = name - paramWithConfig.Param.ElementType = opt.elementType - paramWithConfig.Param.Content = opt.GetWeights() - paramWithConfig.State = opt.GetStates() - content, err := GetBytes(paramWithConfig) - if err != nil { - log.Errorln(err) - } - ck := checkpoint{} - h := md5.New() - ck.Md5sum = hex.EncodeToString(h.Sum(content)) - ck.Timestamp = time.Now().String() - ck.Uuid = checkpoint_path + strconv.Itoa(s.idx) - ckbytes, err := GetBytes(ck) - if err != nil { - log.Errorln(err) - } - // TODO: according design doc, need to save Uuid to etcd in json format - // {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx} - log.Infof("parameter checkpoint %s", ckbytes) - - if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) { - log.Info("checkpoint not exists.") - } else { - err = os.Remove(ck.Uuid) - log.Infof("remove %s", ck.Uuid) - } - f, err := os.Create(ck.Uuid) - defer f.Close() - if err != nil { - log.Errorln(err) - } - writer := bufio.NewWriter(f) - _, err = writer.Write(content) - writer.Flush() - if err != nil { - log.Errorln(err) - } + var pc parameterCheckPoint + pc.ParamConfig.Param.Name = name + pc.ParamConfig.Param.ElementType = opt.elementType + pc.ParamConfig.Param.Content = opt.GetWeights() + pc.State = opt.GetStates() + cp[index] = pc + index++ + } + var buf bytes.Buffer + encoder := gob.NewEncoder(&buf) + err := encoder.Encode(cp) + if err != nil { + return err + } + + cpMeta := checkpointMeta{} + cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) + cpMeta.Timestamp = time.Now().String() + h := md5.New() + cpMeta.Md5sum = h.Sum(buf.Bytes()) + + cpMetajson, err := json.Marshal(cpMeta) + s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) + if err != nil { + return err + } + if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) { + log.Info("checkpoint does not exists.") + } else { + err = os.Remove(cpMeta.UUID) + log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) + } + f, err := os.Create(cpMeta.UUID) + defer f.Close() + if err != nil { + log.Errorln(err) + } + writer := bufio.NewWriter(f) + _, err = writer.Write(buf.Bytes()) + writer.Flush() + if err != nil { + log.Errorln(err) } return nil } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 65a791ae477..75d4732ea78 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -15,7 +15,8 @@ const ( ) func TestServiceFull(t *testing.T) { - s, err := pserver.NewService(0) + var cp pserver.Checkpoint + s, err := pserver.NewService(0, 1, "", nil, cp) if err != nil { t.Error(err) } @@ -83,8 +84,6 @@ func TestServiceFull(t *testing.T) { if !reflect.DeepEqual(param1, p) { t.FailNow() } - var dummy int - s.Save("", &dummy) } func TestMultipleInit(t *testing.T) { -- GitLab