diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index a62bc4310e62e5a90787bf660982b8f52ae34ea6..54fa254863156455f66fa87de9077042a45f9735 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -1,6 +1,7 @@ package main import ( + "fmt" "net" "net/http" "net/rpc" @@ -12,13 +13,13 @@ import ( log "github.com/sirupsen/logrus" "github.com/PaddlePaddle/Paddle/go/master" + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" ) func main() { port := flag.Int("port", 8080, "port of the master server.") - ttlSec := flag.Int("ttl", 60, "etcd lease TTL in seconds.") - endpoints := flag.String("endpoints", "", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") + endpoints := flag.String("endpoints", "http://127.0.0.1:2379", "comma separated etcd endpoints. If empty, fault tolerance will not be enabled.") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") @@ -31,8 +32,13 @@ func main() { var store master.Store if *endpoints != "" { eps := strings.Split(*endpoints, ",") - var err error - store, err = master.NewEtcd(eps, master.DefaultLockPath, master.DefaultStatePath, *ttlSec) + ip, err := networkhelper.GetExternalIP() + if err != nil { + log.Fatal(err) + } + + addr := fmt.Sprintf("%s:%d", ip, *port) + store, err = master.NewEtcdClient(eps, addr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, *ttlSec) if err != nil { log.Fatal(err) } diff --git a/go/master/etcd_store.go b/go/master/etcd_client.go similarity index 56% rename from go/master/etcd_store.go rename to go/master/etcd_client.go index 21b3e2cb0f539c4d89266d273cb74c6f93026ddd..b7293a759896f113d630d57d14b4b4ac8963f54a 100644 --- a/go/master/etcd_store.go +++ b/go/master/etcd_client.go @@ -2,7 +2,7 @@ package 
master import ( "context" - "sync" + "time" "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" @@ -14,22 +14,22 @@ const ( DefaultLockPath = "/master/lock" // DefaultStatePath is the default etcd key for master state. DefaultStatePath = "/master/state" + // DefaultAddrPath is the default etcd key for master address. + DefaultAddrPath = "/master/addr" ) -// Etcd is the etcd abstraction that master uses for fault tolerance +// EtcdClient is the etcd client that master uses for fault tolerance // and service registry. -type Etcd struct { +type EtcdClient struct { lockPath string statePath string - ttlSec int client *clientv3.Client - - mu sync.Mutex - lock *concurrency.Mutex + lock *concurrency.Mutex } -// NewEtcd creates a new Etcd. -func NewEtcd(endpoints []string, lockPath, statePath string, ttlSec int) (*Etcd, error) { +// NewEtcdClient creates a new EtcdClient. +func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePath string, ttlSec int) (*EtcdClient, error) { + log.Debugf("Connecting to etcd at %v", endpoints) // TODO(helin): gracefully shutdown etcd store. Becuase etcd // store holds a etcd lock, even though the lock will expire // when the lease timeout, we need to implement graceful @@ -53,27 +53,35 @@ func NewEtcd(endpoints []string, lockPath, statePath string, ttlSec int) (*Etcd, // one master running, but split-brain problem may cuase // multiple master servers running), and the cluster management // software will kill one of them. 
- log.Infof("Trying to acquire lock at %s.", lockPath) + log.Debugf("Trying to acquire lock at %s.", lockPath) err = lock.Lock(context.TODO()) if err != nil { return nil, err } - log.Infof("Successfully acquired lock at %s.", lockPath) - - e := &Etcd{} - e.client = cli - e.lock = lock - e.lockPath = lockPath - e.statePath = statePath - e.ttlSec = ttlSec + log.Debugf("Successfully acquired lock at %s.", lockPath) + + put := clientv3.OpPut(addrPath, string(addr)) + resp, err := cli.Txn(context.Background()).If(lock.IsOwner()).Then(put).Commit() + if err != nil { + return nil, err + } + + if !resp.Succeeded { + log.Fatal("No longer owns the master lock. Exiting.") + } + + e := &EtcdClient{ + lockPath: lockPath, + statePath: statePath, + client: cli, + lock: lock, + } + return e, nil } // Save saves the state into the etcd. -func (e *Etcd) Save(state []byte) error { - e.mu.Lock() - defer e.mu.Unlock() - +func (e *EtcdClient) Save(state []byte) error { ctx := context.TODO() put := clientv3.OpPut(e.statePath, string(state)) resp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit() @@ -82,17 +90,21 @@ func (e *Etcd) Save(state []byte) error { } if !resp.Succeeded { - log.Errorln("No longer owns the lock, trying to lock and save again.") - sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(e.ttlSec)) - if err != nil { - return err - } - - e.lock = concurrency.NewMutex(sess, e.lockPath) - log.Infof("Try to acquire lock at %s.", e.lockPath) - err = e.lock.Lock(context.TODO()) + log.Errorln("No longer owns the lock, trying to lock again") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + err := e.lock.Lock(ctx) + cancel() if err != nil { - return err + // We lost the master lock and cannot acquire + // it back, it means some other master is + // already started. We don't want cluster + // management system to kill the master server + // who is holding the lock and running + // correctly. 
So the most feasible solution is + // to kill current master server. The current + // state is not saved, but the trainer's RPC + // call will fail, so the trainer will retry. + log.Fatalf("Could not acquire the lock at %s: %v. Exiting.", e.lockPath, err) } log.Infof("Successfully acquired lock at %s.", e.lockPath) return e.Save(state) @@ -102,8 +114,7 @@ func (e *Etcd) Save(state []byte) error { } // Load loads the state from etcd. -func (e *Etcd) Load() ([]byte, error) { - e.mu.Lock() +func (e *EtcdClient) Load() ([]byte, error) { ctx := context.TODO() get := clientv3.OpGet(e.statePath) @@ -114,14 +125,7 @@ func (e *Etcd) Load() ([]byte, error) { if !resp.Succeeded { log.Errorln("No longer owns the lock, trying to lock and load again.") - sess, err := concurrency.NewSession(e.client) - if err != nil { - return nil, err - } - - e.lock = concurrency.NewMutex(sess, e.lockPath) - err = e.lock.Lock(context.TODO()) - e.mu.Unlock() + err = e.lock.Lock(context.Background()) if err != nil { return nil, err } @@ -132,11 +136,9 @@ func (e *Etcd) Load() ([]byte, error) { kvs := resp.Responses[0].GetResponseRange().Kvs if len(kvs) == 0 { // No state exists - e.mu.Unlock() return nil, nil } state := kvs[0].Value - e.mu.Unlock() return state, nil }