diff --git a/go/master/client.go b/go/master/client.go index d3bea49d0a8166420e83478076cc7bc81e48598d..b6ca8cad15a976c5b493871fcd5d34543cda5096 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -112,6 +112,11 @@ func (c *Client) taskFinished(taskID int) error { return c.conn.Call("Service.TaskFinished", taskID, nil) } +// TaskFailed tell the master server as task is failed. +func (c *Client) taskFailed(taskID int, epoch int) error { + return c.conn.Call("Service.TaskFinished", taskID, epoch) +} + // NextRecord returns next record in the dataset. // // NextRecord will block until the next record is available. It is diff --git a/go/master/service.go b/go/master/service.go index 58e68e744859933aa74cac231356d4ff9dfb8d7b..b078f318f57ab78e67835a1f9dcec172b2135ba2 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -34,29 +34,30 @@ type Chunk struct { // Task is the basic unit of data instances assigned to trainers. type Task struct { ID int + Epoch int Chunks []Chunk } type taskEntry struct { - Epoch int NumTimeout int Task Task + FailedNum int } type taskQueues struct { Todo []taskEntry Pending map[int]taskEntry // map from task ID to task entry Done []taskEntry - Failed []Task + Failed []taskEntry } // Service is the master server service. type Service struct { - chunksPerTask int - timeoutDur time.Duration - timeoutMax int - ready chan struct{} - store Store + chunksPerTask int + timeoutDur time.Duration + failortimeoutMax int + ready chan struct{} + store Store mu sync.Mutex initDone bool @@ -91,11 +92,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { } // NewService creates a new service. -func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { +func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failortimeoutMax int) (*Service, error) { s := &Service{} s.chunksPerTask = chunksPerTask s.timeoutDur = timeoutDur - s.timeoutMax = timeoutMax + s.failortimeoutMax = failortimeoutMax s.taskQueues = taskQueues{} s.taskQueues.Pending = make(map[int]taskEntry) s.ready = make(chan struct{}) @@ -257,6 +258,34 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { return nil } +func (s *Service) checkTaskStatus(t taskEntry, epoch int) { + if t.Task.Epoch != epoch { + // new epoch, task launched after the + // schedule of this timeout check or failed status report. + return + } + + defer func() { + err := s.snapshot() + if err != nil { + log.Errorln(err) + } + }() + + delete(s.taskQueues.Pending, t.Task.ID) + + t.NumTimeout++ + if t.NumTimeout+t.FailedNum > s.failortimeoutMax { + log.Warningf("Task %v timed out %d times and failed %d times, discard.", t.Task, t.NumTimeout, t.FailedNum) + s.taskQueues.Failed = append(s.taskQueues.Failed, t) + return + } + + log.Warningf("Task %v timed out %d times and failed %d times, discard.", t.Task, t.NumTimeout, t.FailedNum) + s.taskQueues.Todo = append(s.taskQueues.Todo, t) + return +} + func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { return func() { s.mu.Lock() @@ -267,30 +296,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { return } - if t.Epoch != epoch { - // new epoch, task launched after the - // schedule of this timeout check. - return - } - - defer func() { - err := s.snapshot() - if err != nil { - log.Errorln(err) - } - }() - - delete(s.taskQueues.Pending, t.Task.ID) - - t.NumTimeout++ - if t.NumTimeout > s.timeoutMax { - log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout) - s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) - return - } - - log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout) - s.taskQueues.Todo = append(s.taskQueues.Todo, t) + s.checkTaskStatus(t, epoch) } } @@ -339,7 +345,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { } t := s.taskQueues.Todo[0] - t.Epoch++ + t.Task.Epoch++ s.taskQueues.Todo = s.taskQueues.Todo[1:] s.taskQueues.Pending[t.Task.ID] = t err := s.snapshot() @@ -348,9 +354,9 @@ func (s *Service) GetTask(dummy int, task *Task) error { } *task = t.Task - log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID) + log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t) - time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch)) + time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Task.Epoch)) return nil } @@ -372,6 +378,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { // task finished, reset timeout t.NumTimeout = 0 + t.FailedNum = 0 s.taskQueues.Done = append(s.taskQueues.Done, t) delete(s.taskQueues.Pending, taskID) @@ -389,3 +396,23 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { } return err } + +// TaskFailed tell the service that a task is failed. +func (s *Service) TaskFailed(taskID int, epoch int) error { + select { + case <-s.ready: + } + + s.mu.Lock() + defer s.mu.Unlock() + + t, ok := s.taskQueues.Pending[taskID] + if !ok { + err := errors.New("pending task not found") + log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%d not found.", taskID) + return err + } + + s.checkTaskStatus(t, epoch) + return nil +}