From 19bfb8a1f2153f0d5368808bc09580c7e4c7b07c Mon Sep 17 00:00:00 2001 From: Yancey Date: Thu, 13 Jul 2017 09:52:26 +0800 Subject: [PATCH] PServer recovery from checkpoint (#2741) * Server recovery from checkpoint --- .gitignore | 3 ++ go/cmd/pserver/pserver.go | 39 +++++++------- go/glide.lock | 14 ++--- go/glide.yaml | 1 + go/pserver/etcd_client.go | 22 ++++++-- go/pserver/service.go | 104 ++++++++++++++++++++++++++------------ 6 files changed, 121 insertions(+), 62 deletions(-) diff --git a/.gitignore b/.gitignore index 5c2fb134ae..c84b2fc8c7 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ cmake-build-* # generated while compiling python/paddle/v2/framework/core.so +CMakeFiles +cmake_install.cmake + diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 0ecb1242c3..48351ab6d0 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -8,6 +8,7 @@ import ( "time" "github.com/namsral/flag" + "github.com/topicai/candy" "github.com/PaddlePaddle/Paddle/go/pserver" log "github.com/sirupsen/logrus" @@ -18,53 +19,47 @@ func main() { index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") - etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") + etcdTimeout := flag.Duration("etcd-timeout", 5*time.Second, "timeout for etcd calls") numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") checkpointPath := flag.String("checkpoint-path", "/checkpoints/", "save checkpoint path") - checkpointInterval := flag.Int("checkpoint-interval", 600, "save checkpoint per interval seconds") + checkpointInterval := flag.Duration("checkpoint-interval", 600*time.Second, "save checkpoint per interval seconds") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() level, err := log.ParseLevel(*logLevel) - if err != nil { - panic(err) - } + candy.Must(err) + log.SetLevel(level) var idx int - var cp pserver.Checkpoint + + var cp *pserver.Checkpoint var e *pserver.EtcdClient if *index >= 0 { idx = *index } else { - timeout := time.Second * time.Duration((*etcdTimeout)) - e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + e = pserver.NewEtcdClient(*etcdEndpoint, *numPservers, *etcdTimeout) idx, err = e.Register() + candy.Must(err) + + cp, err = pserver.NewCheckpointFromFile(*checkpointPath, idx, e) if err != nil { - panic(err) + log.Errorf("Fetch checkpoint failed, %s", err) } } s, err := pserver.NewService(idx, *checkpointInterval, *checkpointPath, e, cp) - if err != nil { - panic(err) - } + candy.Must(err) + err = rpc.Register(s) - if err != nil { - panic(err) - } + candy.Must(err) rpc.HandleHTTP() l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) - if err != nil { - panic(err) - } + candy.Must(err) log.Infof("start pserver at port %d", *port) err = http.Serve(l, nil) - - if err != nil { - panic(err) - } + candy.Must(err) } diff --git a/go/glide.lock b/go/glide.lock index 190a222338..f71ae643d6 100644 --- a/go/glide.lock +++ b/go/glide.lock @@ -1,8 +1,8 @@ -hash: b8f18ce6784bd3fadd9fed0b8443e7b658234ea785ae1f220723ae2c1f652aa7 -updated: 2017-06-27T14:05:48.925262819+08:00 +hash: a8faea3a363468a88917ddeb3b1c9ea36886fb2c622acbad42604fa9cb4d3855 +updated: 2017-07-11T10:04:40.786745417+08:00 imports: - name: github.com/coreos/etcd - version: 61fc123e7a8b14a0a258aa3f5c4159861b1ec2e7 + version: cb2a496c4ddd1c87a9f280e116649b599999ec79 subpackages: - auth/authpb - clientv3 @@ -22,7 +22,9 @@ imports: - name: github.com/PaddlePaddle/recordio version: edfb82af0739c84f241c87390ec5649c7b28c129 - name: github.com/sirupsen/logrus - version: 202f25545ea4cf9b191ff7f846df5d87c9382c2b + version: 7f976d3a76720c4c27af2ba716b85d2e0a7e38b1 +- name: github.com/topicai/candy + version: 1b9030d056fa9f8c4b1f9c91b52fe4b8ab4cd8cc - name: golang.org/x/net version: c8c74377599bd978aee1cf3b9b63a8634051cec2 subpackages: @@ -34,11 +36,11 @@ imports: - lex/httplex - trace - name: golang.org/x/sys - version: f7928cfef4d09d1b080aa2b6fd3ca9ba1567c733 + version: abf9c25f54453410d0c6668e519582a9e1115027 subpackages: - unix - name: golang.org/x/text - version: 4e9ab9ee170f2a39bd66c92b3e0a47ff47a4bc77 + version: cfdf022e86b4ecfb646e1efbd7db175dd623a8fa subpackages: - secure/bidirule - transform diff --git a/go/glide.yaml b/go/glide.yaml index 05c5d15ca2..ab472c7cda 100644 --- a/go/glide.yaml +++ b/go/glide.yaml @@ -10,3 +10,4 @@ import: version: ^1.7.4-pre - package: github.com/sirupsen/logrus version: ^1.0.0 +- package: github.com/topicai/candy diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 1f77787150..4a694b97f4 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -16,7 +16,7 @@ import ( const ( // PsDesired is etcd path for store desired pserver count PsDesired = "/ps_desired" - // PsAddr is the base dir for pserver to store their addr + // PsPath is the base dir for pserver to store their addr PsPath = "/ps/" // PsCheckpoint is the etcd path for store checkpoints information PsCheckpoint = "/checkpoints/" @@ -189,9 +189,25 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { return idx, nil } +// GetKey gets the value by the specified key +func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + resp, err := e.etcdClient.Get(ctx, key) + cancel() + if err != nil { + return []byte{}, err + } + kvs := resp.Kvs + if len(kvs) == 0 { + return []byte{}, nil + } + v := kvs[0].Value + return v, nil +} + // PutKey put into etcd with value by key specified -func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout)) +func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(context.Background(), timeout) _, err := e.etcdClient.Put(ctx, key, string(value)) cancel() if err != nil { diff --git a/go/pserver/service.go b/go/pserver/service.go index 6b52d0d896..65db6970a7 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -9,6 +9,7 @@ import ( "encoding/json" "errors" "fmt" + "io/ioutil" "os" "path/filepath" "strconv" @@ -21,14 +22,14 @@ import ( // ElementType is the type of elements of a Parameter. type ElementType int +// RPC error message. const ( - // AlreadyInitialized is true if pserver is initialized - AlreadyInitialized = "pserver already initialized" - // Uninitialized is true if pserver not fully initialized - Uninitialized = "pserver not fully initialized" + AlreadyInitialized = "pserver already initialized" + Uninitialized = "pserver not fully initialized" + CheckpointMD5Failed = "checkpoint file MD5 validation failed" ) -// Supported element types +// Supported element types. const ( Int32 ElementType = iota UInt32 @@ -51,21 +52,15 @@ type ParameterWithConfig struct { Config []byte // parameter configuration in Proto Buffer format } -// ParameterCheckpoint is Parameter and State checkpoint -type ParameterCheckpoint struct { - ParamConfig ParameterWithConfig - State []byte -} - -// checkpoint signature +// checkpointMeta saves checkpoint metadata type checkpointMeta struct { UUID string `json:"uuid"` - Md5sum string `json:"md5sum"` - Timestamp string `json:"timestamp"` + MD5 string `json:"md5"` + Timestamp int64 `json:"timestamp"` } // Checkpoint is the pserver shard persist in file -type Checkpoint []ParameterCheckpoint +type Checkpoint []parameterCheckpoint // Gradient is the gradient of the parameter. type Gradient Parameter @@ -81,12 +76,53 @@ type Service struct { optMap map[string]*optimizer } +// parameterCheckpoint saves parameter checkpoint +type parameterCheckpoint struct { + ParameterWithConfig + State []byte +} + +// NewCheckpointFromFile loads parameters and state from checkpoint file +func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) { + v, err := e.GetKey(PsPath+string(idx), 3*time.Second) + if err != nil { + return nil, err + } + + var cpMeta checkpointMeta + if err = json.Unmarshal(v, &cpMeta); err != nil { + return nil, err + } + + fn := filepath.Join(cpPath, cpMeta.UUID) + if _, err = os.Stat(fn); os.IsNotExist(err) { + return nil, err + } + content, err := ioutil.ReadFile(fn) + if err != nil { + return nil, err + } + + h := md5.New() + md5 := hex.EncodeToString(h.Sum(content)) + if md5 != cpMeta.MD5 { + return nil, errors.New(CheckpointMD5Failed) + } + + dec := gob.NewDecoder(bytes.NewReader(content)) + cp := &Checkpoint{} + if err = dec.Decode(cp); err != nil { + return nil, err + } + return cp, nil +} + // NewService creates a new service, will bypass etcd registration if no -// endpoints specified. -func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) { +// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint. +func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) { s := &Service{ idx: idx, - checkpointInterval: time.Second * time.Duration(seconds), + checkpointInterval: interval, checkpointPath: path, client: client, } @@ -94,10 +130,12 @@ func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkp s.initialized = make(chan struct{}) if cp != nil { - for _, item := range cp { - p := item.ParamConfig - st := item.State - s.optMap[p.Param.Name] = newOptimizer(p, st) + for _, item := range *cp { + p := ParameterWithConfig{ + Param: item.Param, + Config: item.Config, + } + s.optMap[p.Param.Name] = newOptimizer(p, item.State) } } return s, nil @@ -186,13 +224,13 @@ func (s *Service) doCheckpoint() error { s.mu.Lock() defer s.mu.Unlock() - cp := make([]ParameterCheckpoint, 0, len(s.optMap)) + cp := make([]parameterCheckpoint, len(s.optMap)) index := 0 for name, opt := range s.optMap { - var pc ParameterCheckpoint - pc.ParamConfig.Param.Name = name - pc.ParamConfig.Param.ElementType = opt.elementType - pc.ParamConfig.Param.Content = opt.GetWeights() + var pc parameterCheckpoint + pc.Param.Name = name + pc.Param.ElementType = opt.elementType + pc.Param.Content = opt.GetWeights() pc.State = opt.GetStates() cp[index] = pc index++ @@ -206,12 +244,12 @@ func (s *Service) doCheckpoint() error { cpMeta := checkpointMeta{} cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx) - cpMeta.Timestamp = time.Now().String() + cpMeta.Timestamp = time.Now().UnixNano() h := md5.New() - cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes())) + cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes())) cpMetajson, _ := json.Marshal(cpMeta) - err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3) + err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second) if err != nil { return err } @@ -219,7 +257,11 @@ func (s *Service) doCheckpoint() error { log.Info("checkpoint does not exists.") } else { err = os.Remove(cpMeta.UUID) - log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) + if err != nil { + log.Infof("Removing checkpoint %s failed", cpMeta.UUID) + } else { + log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID) + } } f, err := os.Create(cpMeta.UUID) defer f.Close() -- GitLab