diff --git a/go/master/service.go b/go/master/service.go index d30e9a33229c0aff354417771b5bf2ae6a781715..f072dd786c966886b068f9afe9bca3e63fb6bb5b 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -77,11 +77,13 @@ type taskEntry struct { NumFailure int } -type taskQueues struct { - Todo []taskEntry - Pending map[int]taskEntry // map from task ID to task entry - Done []taskEntry - Failed []taskEntry +type masterState struct { + Todo []taskEntry + Pending map[int]taskEntry // map from task ID to task entry + Done []taskEntry + Failed []taskEntry + CurPass int + JobTasks []taskEntry } // Service is the master server service. @@ -94,11 +96,11 @@ type Service struct { ready chan struct{} initDone bool - mu sync.Mutex - taskQueues taskQueues - currPass int - jobTasks []taskEntry - + mu sync.Mutex + // State to be persisted to snapshot. + state masterState + // The trainer that is currently saving model. This state is + // transient, does not need to be persisted to snapshot. savingTrainer string } @@ -141,8 +143,8 @@ func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failur s.chunksPerTask = chunksPerTask s.timeoutDur = timeoutDur s.failureMax = failureMax - s.taskQueues = taskQueues{} - s.taskQueues.Pending = make(map[int]taskEntry) + s.state = masterState{} + s.state.Pending = make(map[int]taskEntry) s.ready = make(chan struct{}) s.store = store recovered, err := s.recover() @@ -180,7 +182,7 @@ func (s *Service) recover() (bool, error) { } dec := gob.NewDecoder(gr) - var tqs taskQueues + var tqs masterState err = dec.Decode(&tqs) if err != nil { return false, err @@ -193,7 +195,12 @@ func (s *Service) recover() (bool, error) { log.Errorln(err) } - s.taskQueues = tqs + s.state = tqs + log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.") + for _, t := range s.state.Pending { + time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) + } + return true, nil } @@ -208,7 +215,7 @@ func (s *Service) snapshot() error { var buf bytes.Buffer gw := gzip.NewWriter(&buf) enc := gob.NewEncoder(gw) - err := enc.Encode(s.taskQueues) + err := enc.Encode(s.state) if err != nil { return err } @@ -290,8 +297,8 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error { return err } - s.jobTasks = partition(chunks, s.chunksPerTask) - s.taskQueues.Todo = s.jobTasks + s.state.JobTasks = partition(chunks, s.chunksPerTask) + s.state.Todo = s.state.JobTasks err = s.snapshot() if err != nil { @@ -319,17 +326,17 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { } }() - delete(s.taskQueues.Pending, t.Task.Meta.ID) + delete(s.state.Pending, t.Task.Meta.ID) t.NumFailure++ if t.NumFailure > s.failureMax { log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) - s.taskQueues.Failed = append(s.taskQueues.Failed, t) + s.state.Failed = append(s.state.Failed, t) return } log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure) - s.taskQueues.Todo = append(s.taskQueues.Todo, t) + s.state.Todo = append(s.state.Todo, t) return } @@ -338,7 +345,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { s.mu.Lock() defer s.mu.Unlock() - t, ok := s.taskQueues.Pending[taskID] + t, ok := s.state.Pending[taskID] if !ok { return } @@ -350,10 +357,10 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { // must be called with lock held. func (s *Service) logFields() log.Fields { return log.Fields{ - "todoLen": len(s.taskQueues.Todo), - "pendingLen": len(s.taskQueues.Pending), - "doneLen": len(s.taskQueues.Done), - "failedLen": len(s.taskQueues.Failed), + "todoLen": len(s.state.Todo), + "pendingLen": len(s.state.Pending), + "doneLen": len(s.state.Done), + "failedLen": len(s.state.Failed), } } @@ -366,17 +373,17 @@ func (s *Service) GetTask(passID int, task *Task) error { s.mu.Lock() defer s.mu.Unlock() - if passID < s.currPass { + if passID < s.state.CurPass { return ErrPassBefore } - if passID > s.currPass { + if passID > s.state.CurPass { // Client may get run to pass after master when one client faster than the // other return ErrPassAfter } - if len(s.taskQueues.Todo) == 0 { - if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 { + if len(s.state.Todo) == 0 { + if len(s.state.Done) == 0 && len(s.state.Pending) == 0 { log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass") return ErrAllTaskFailed } @@ -384,10 +391,10 @@ func (s *Service) GetTask(passID int, task *Task) error { return ErrNoMoreAvailable } - t := s.taskQueues.Todo[0] + t := s.state.Todo[0] t.Task.Meta.Epoch++ - s.taskQueues.Todo = s.taskQueues.Todo[1:] - s.taskQueues.Pending[t.Task.Meta.ID] = t + s.state.Todo = s.state.Todo[1:] + s.state.Pending[t.Task.Meta.ID] = t err := s.snapshot() if err != nil { return err @@ -409,7 +416,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { s.mu.Lock() defer s.mu.Unlock() - t, ok := s.taskQueues.Pending[taskID] + t, ok := s.state.Pending[taskID] if !ok { log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) return nil @@ -417,18 +424,18 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { // task finished, reset timeout t.NumFailure = 0 - s.taskQueues.Done = append(s.taskQueues.Done, t) - delete(s.taskQueues.Pending, taskID) + s.state.Done = append(s.state.Done, t) + delete(s.state.Pending, taskID) log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) - if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 { + if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 { // increase master side pass count if all tasks finished - s.currPass++ - s.taskQueues.Todo = s.jobTasks - s.taskQueues.Done = []taskEntry{} + s.state.CurPass++ + s.state.Todo = s.state.JobTasks + s.state.Done = []taskEntry{} // TODO(typhoonzero): deal with failed tasks - s.taskQueues.Failed = []taskEntry{} - log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass) + s.state.Failed = []taskEntry{} + log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass) } err := s.snapshot() @@ -447,7 +454,7 @@ func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { s.mu.Lock() defer s.mu.Unlock() - t, ok := s.taskQueues.Pending[meta.ID] + t, ok := s.state.Pending[meta.ID] if !ok { log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta) return nil diff --git a/go/pserver/client/etcd_client.go b/go/pserver/client/etcd_client.go index b6ff1fec8a6f37f61f38cb5d004b1d2c886473ed..977ae5af37e2b7d647ae16af9c4403f916b0216d 100644 --- a/go/pserver/client/etcd_client.go +++ b/go/pserver/client/etcd_client.go @@ -103,7 +103,7 @@ func (p *EtcdClient) List() []Server { time.Sleep(p.timeout) continue } - log.Infof("got value (%s) for key: %s", psAddr, psKey) + log.Debugf("got value (%s) for key: %s", psAddr, psKey) servers[i].Index = i servers[i].Addr = psAddr }