From fd8937556f95db4086ce095efa1e83041c896334 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 20 Jun 2017 23:57:07 +0000 Subject: [PATCH] Master save and load state from etcd --- go/cmd/master/master.go | 55 ++++++++++-- go/master/client_internal_test.go | 21 ++++- go/master/client_test.go | 21 ++++- go/master/etcd_store.go | 133 ++++++++++++++++++++++++++++ go/master/service.go | 142 +++++++++++++++++++++++------- go/pserver/cclient/cclient.go | 6 +- 6 files changed, 330 insertions(+), 48 deletions(-) create mode 100644 go/master/etcd_store.go diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 25cd1cafcd..49ad0300b8 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -5,41 +5,80 @@ import ( "net/http" "net/rpc" "strconv" + "strings" + "sync" "time" "github.com/namsral/flag" + log "github.com/sirupsen/logrus" "github.com/PaddlePaddle/Paddle/go/master" ) +type inMemStore struct { + mu sync.Mutex + buf []byte +} + +func (m *inMemStore) Save(b []byte) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.buf = b + return nil +} + +func (m *inMemStore) Load() ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + return m.buf, nil +} + func main() { port := flag.Int("port", 8080, "port of the master server.") - faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") + ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") + endpoints := flag.String("endpoints", "", "comma separated etcd endpoints. 
If empty, fault tolerance will not be enabled.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") flag.Parse() - if *faultTolerance { - panic("fault tolernance not implemented.") + if *endpoints == "" { + log.Warningln("-endpoints not set, fault tolerance not be enabled.") + } + + var store master.Store + if *endpoints != "" { + eps := strings.Split(*endpoints, ",") + var err error + store, err = master.NewEtcdStore(eps, master.DefaultLockPath, master.DefaultStatePath, *ttlSec) + if err != nil { + log.Fatal(err) + } + } else { + store = &inMemStore{} + } + s, err := master.NewService(store, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) + if err != nil { + log.Fatal(err) } - s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) - err := rpc.Register(s) + err = rpc.Register(s) if err != nil { - panic(err) + log.Fatal(err) } rpc.HandleHTTP() l, err := net.Listen("tcp", ":"+strconv.Itoa(*port)) if err != nil { - panic(err) + log.Fatal(err) } err = http.Serve(l, nil) if err != nil { - panic(err) + log.Fatal(err) } } diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index 00fcca0e2c..a5b76fe853 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -32,6 +32,19 @@ func (a TestAddresser) Address() string { return string(a) } +type myStore struct { + buf []byte +} + +func (m *myStore) Save(b []byte) error { + m.buf = b + return nil +} + +func (m *myStore) Load() ([]byte, error) { + return m.buf, nil +} + func TestGetFinishTask(t *testing.T) { const path = "/tmp/master_client_test_0" @@ -47,9 +60,13 @@ func TestGetFinishTask(t *testing.T) { } go func(l net.Listener) { - s := NewService(chunkPerTask, time.Second, 1) + s, err := NewService(&myStore{}, 
chunkPerTask, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/client_test.go b/go/master/client_test.go index 2b3f873ecf..ae5f17c2d4 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -15,6 +15,19 @@ import ( "github.com/PaddlePaddle/recordio" ) +type myStore struct { + buf []byte +} + +func (m *myStore) Save(b []byte) error { + m.buf = b + return nil +} + +func (m *myStore) Load() ([]byte, error) { + return m.buf, nil +} + func TestNextRecord(t *testing.T) { const ( path = "/tmp/master_client_TestFull" @@ -33,9 +46,13 @@ func TestNextRecord(t *testing.T) { } go func(l net.Listener) { - s := master.NewService(10, time.Second, 1) + s, err := master.NewService(&myStore{}, 10, time.Second, 1) + if err != nil { + panic(err) + } + server := rpc.NewServer() - err := server.Register(s) + err = server.Register(s) if err != nil { panic(err) } diff --git a/go/master/etcd_store.go b/go/master/etcd_store.go new file mode 100644 index 0000000000..ce178370ff --- /dev/null +++ b/go/master/etcd_store.go @@ -0,0 +1,133 @@ +package master + +import ( + "context" + "sync" + + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" +) + +const ( + // DefaultLockPath is the default etcd master lock path. + DefaultLockPath = "/master/lock" + // DefaultStatePath is the default etcd key for master state. + DefaultStatePath = "/master/state" +) + +// EtcdStore is the Store implementation backed by etcd. +type EtcdStore struct { + lockPath string + statePath string + ttlSec int + client *clientv3.Client + + mu sync.Mutex + lock *concurrency.Mutex +} + +// NewEtcdStore creates a new EtcdStore. 
+//
+// It connects to the given etcd endpoints, opens a session with a TTL
+// of ttlSec seconds and blocks until the lock at lockPath is acquired.
+// The master state is kept under the key statePath. If the session
+// cannot be created or the lock cannot be acquired, the etcd client is
+// closed before the error is returned.
+func NewEtcdStore(endpoints []string, lockPath, statePath string, ttlSec int) (*EtcdStore, error) {
+	cli, err := clientv3.New(clientv3.Config{
+		Endpoints:   endpoints,
+		DialTimeout: dialTimeout,
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	sess, err := concurrency.NewSession(cli, concurrency.WithTTL(ttlSec))
+	if err != nil {
+		// Do not leak the client when the store cannot be created.
+		if cerr := cli.Close(); cerr != nil {
+			log.Errorln(cerr)
+		}
+		return nil, err
+	}
+
+	lock := concurrency.NewMutex(sess, lockPath)
+	// It's fine for the lock to get stuck, in this case we have
+	// multiple master servers running (only configured to have
+	// one master running, but split-brain problem may cause
+	// multiple master servers running), and the cluster management
+	// software will kill one of them.
+	log.Infof("Trying to acquire lock at %s.", lockPath)
+	err = lock.Lock(context.TODO())
+	if err != nil {
+		// Do not leak the client when the store cannot be created.
+		if cerr := cli.Close(); cerr != nil {
+			log.Errorln(cerr)
+		}
+		return nil, err
+	}
+	log.Infof("Successfully acquired lock at %s.", lockPath)
+
+	e := &EtcdStore{}
+	e.client = cli
+	e.lock = lock
+	e.lockPath = lockPath
+	e.statePath = statePath
+	e.ttlSec = ttlSec
+	return e, nil
+}
+
+// Save saves the state into the etcd.
+//
+// The write is done in a transaction that only commits while this
+// store owns the lock. If the lock is no longer owned, Save creates a
+// new session, acquires the lock again and retries the write. It
+// returns the first error reported by etcd.
+func (e *EtcdStore) Save(state []byte) error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	put := clientv3.OpPut(e.statePath, string(state))
+	// Retry in a loop rather than by calling Save again: e.mu is held
+	// until this function returns and sync.Mutex is not reentrant, so
+	// a recursive call would block forever.
+	for {
+		resp, err := e.client.Txn(context.TODO()).If(e.lock.IsOwner()).Then(put).Commit()
+		if err != nil {
+			return err
+		}
+
+		if resp.Succeeded {
+			return nil
+		}
+
+		log.Errorln("No longer owns the lock, trying to lock and save again.")
+		sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(e.ttlSec))
+		if err != nil {
+			return err
+		}
+
+		lock := concurrency.NewMutex(sess, e.lockPath)
+		log.Infof("Try to acquire lock at %s.", e.lockPath)
+		err = lock.Lock(context.TODO())
+		if err != nil {
+			return err
+		}
+		log.Infof("Successfully acquired lock at %s.", e.lockPath)
+		// Only replace the stored lock once it is actually held.
+		e.lock = lock
+	}
+}
+
+// Load loads the state from etcd.
+//
+// The read is done in a transaction that only commits while this store
+// owns the lock. If the lock is no longer owned, Load creates a new
+// session, acquires the lock again and retries the read. It returns
+// nil state and a nil error when no state exists under the state key.
+func (e *EtcdStore) Load() ([]byte, error) {
+	e.mu.Lock()
+	// Deferred so that every return path, including the error paths,
+	// releases the mutex.
+	defer e.mu.Unlock()
+
+	get := clientv3.OpGet(e.statePath)
+	for {
+		resp, err := e.client.Txn(context.TODO()).If(e.lock.IsOwner()).Then(get).Commit()
+		if err != nil {
+			return nil, err
+		}
+
+		if resp.Succeeded {
+			kvs := resp.Responses[0].GetResponseRange().Kvs
+			if len(kvs) == 0 {
+				// No state exists
+				return nil, nil
+			}
+
+			return kvs[0].Value, nil
+		}
+
+		log.Errorln("No longer owns the lock, trying to lock and load again.")
+		// Use the configured TTL, the same as Save does.
+		sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(e.ttlSec))
+		if err != nil {
+			return nil, err
+		}
+
+		lock := concurrency.NewMutex(sess, e.lockPath)
+		err = lock.Lock(context.TODO())
+		if err != nil {
+			return nil, err
+		}
+		// Only replace the stored lock once it is actually held.
+		e.lock = lock
+	}
+}
diff --git a/go/master/service.go b/go/master/service.go index 55e1e2d1a4..d453777b05 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -1,6 +1,9 @@ package master import ( + "bytes" + "compress/gzip" + "encoding/gob" "errors" "os" "path/filepath" @@ -12,24 +15,54 @@ import ( "github.com/PaddlePaddle/recordio" ) +const ( + dialTimeout = 5 * time.Second +) + +// Store is the interface for save and load the master state. +type Store interface { + Save([]byte) error + Load() ([]byte, error) +} + +// Chunk is a chunk of data consisted of several data instances. +type Chunk struct { + Path string + Index recordio.Index // chunk index +} + +// Task is the basic unit of data instances assigned to trainers. +type Task struct { + ID int + Chunks []Chunk +} + +type taskEntry struct { + Epoch int + NumTimeout int + Task Task +} + +type taskQueues struct { + Todo []taskEntry + Pending map[int]taskEntry // map from task ID to task entry + Done []taskEntry + Failed []Task +} + // Service is the master server service. type Service struct { chunksPerTask int timeoutDur time.Duration timeoutMax int ready chan struct{} + store Store mu sync.Mutex initDone bool taskQueues taskQueues } -// Recover recovers service state from etcd.
-func Recover() (*Service, error) { - // TODO(helin): recover from snapshot state from etcd. - return nil, nil -} - func partition(chunks []Chunk, chunksPerTask int) []taskEntry { id := 0 if chunksPerTask <= 0 { @@ -58,7 +91,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { } // NewService creates a new service. -func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { +func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { s := &Service{} s.chunksPerTask = chunksPerTask s.timeoutDur = timeoutDur @@ -66,38 +99,81 @@ func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Se s.taskQueues = taskQueues{} s.taskQueues.Pending = make(map[int]taskEntry) s.ready = make(chan struct{}) - return s -} + s.store = store + recovered, err := s.recover() + if err != nil { + return nil, err + } -// Chunk is a chunk of data consisted of several data instances. -type Chunk struct { - Path string - Index recordio.Index // chunk index -} + if recovered { + // Recovered. Now the state is already initialized, + // and the master is ready. + s.initDone = true + close(s.ready) + } -// Task is the basic unit of data instances assigned to trainers. -type Task struct { - ID int - Chunks []Chunk + return s, nil } -type taskEntry struct { - Epoch int - NumTimeout int - Task Task -} +// recover recovers service state from etcd. 
+func (s *Service) recover() (bool, error) { + state, err := s.store.Load() + if err != nil { + return false, err + } -type taskQueues struct { - Todo []taskEntry - Pending map[int]taskEntry // map from task ID to task entry - Done []taskEntry - Failed []Task + if state == nil { + log.Infoln("No state exists, not recovered.") + return false, nil + } + + log.Infof("Loaded snapshot of size: %d bytes.", len(state)) + gr, err := gzip.NewReader(bytes.NewReader(state)) + if err != nil { + return false, err + } + + dec := gob.NewDecoder(gr) + var tqs taskQueues + err = dec.Decode(&tqs) + if err != nil { + return false, err + } + + err = gr.Close() + if err != nil { + // Only close failed, recover actually succeed, so + // just log error. + log.Errorln(err) + } + + s.taskQueues = tqs + return true, nil } -// *must* be called with s.mu being held. +// snapshot *must* be called with s.mu being held. func (s *Service) snapshot() error { - // TODO(helin): snapshot state on etcd. - return nil + // TOOD(helin): etcd request has a size limit, so the snapshot + // size is limited by the max request size. 
We should either + // divide the snapshot into smaller chunks and save under + // different keys, or configure the request size to be big + // enough: + // https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44 + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + enc := gob.NewEncoder(gw) + err := enc.Encode(s.taskQueues) + if err != nil { + return err + } + err = gw.Close() + if err != nil { + return err + } + + state := buf.Bytes() + log.Infof("Saving snapshot of size: %d bytes.", len(state)) + return s.store.Save(state) } func readChunks(globPaths []string) ([]Chunk, error) { @@ -207,12 +283,12 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { t.NumTimeout++ if t.NumTimeout > s.timeoutMax { - log.Warningf("Task %v timed out %d times, discard.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout) s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) return } - log.Warningf("Task %v timed out %d times, retry.\n", t.Task, t.NumTimeout) + log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout) s.taskQueues.Todo = append(s.taskQueues.Todo, t) } } diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 92a41b7f54..bbaf43d9f1 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -133,7 +133,7 @@ func paddle_init_param(client C.paddle_pserver_client, param C.paddle_parameter, if err != nil { if err.Error() == pserver.AlreadyInitialized { - log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.\n", name) + log.Warningf("parameter %s already initialized, treat paddle_init_param as sucessful.", name) return C.PSERVER_OK } log.Errorln(err) @@ -200,7 +200,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong number of parameters. 
Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong number of parameters. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } @@ -210,7 +210,7 @@ func paddle_get_params(client C.paddle_pserver_client, dst **C.paddle_parameter, for i, p := range ps { pn[i] = p.Name } - log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.\n", strings.Join(pn, ", "), strings.Join(ns, ", ")) + log.Errorf("pserver returned wrong parameters, or not in requested order. Requested: %s, returned: %s.", strings.Join(pn, ", "), strings.Join(ns, ", ")) return C.PSERVER_ERROR } } -- GitLab