From a40a7a5cb1a80f5489800dd6cda329667ac47c4d Mon Sep 17 00:00:00 2001 From: gongweibao Date: Tue, 11 Jul 2017 10:25:30 +0800 Subject: [PATCH] fix by helin's comments --- go/master/client.go | 6 ++-- go/master/client_internal_test.go | 4 +-- go/master/service.go | 47 +++++++++++++++--------------- go/master/service_internal_test.go | 2 +- 4 files changed, 29 insertions(+), 30 deletions(-) diff --git a/go/master/client.go b/go/master/client.go index bf2612d91..6f06fd042 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -62,7 +62,7 @@ func (c *Client) getRecords() { // We treat a task as finished whenever the last data // instance of the task is read. This is not exactly // correct, but a reasonable approximation. - c.taskFinished(t.ID) + c.taskFinished(t.Meta.ID) } } @@ -113,8 +113,8 @@ func (c *Client) taskFinished(taskID int) error { } // TaskFailed tell the master server as task is failed. -func (c *Client) taskFailed(taskID TaskID) error { - return c.conn.Call("Service.TaskFinished", taskID, nil) +func (c *Client) taskFailed(meta TaskMeta) error { + return c.conn.Call("Service.TaskFinished", meta, nil) } // NextRecord returns next record in the dataset. diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index 364dce7b5..dc4d9eab1 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -95,7 +95,7 @@ func TestGetFinishTask(t *testing.T) { t.Fatalf("Should get error, pass: %d\n", i) } - err = c.taskFinished(tasks[0].ID) + err = c.taskFinished(tasks[0].Meta.ID) if err != nil { t.Fatalf("Error: %v, pass: %d\n", err, i) } @@ -107,7 +107,7 @@ func TestGetFinishTask(t *testing.T) { tasks = append(tasks, task) for _, task := range tasks { - err = c.taskFinished(task.ID) + err = c.taskFinished(task.Meta.ID) if err != nil { t.Fatalf("Error: %v, pass: %d\n", err, i) } diff --git a/go/master/service.go b/go/master/service.go index daf392823..1291ac48f 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -31,10 +31,15 @@ type Chunk struct { Index recordio.Index // chunk index } +// TaskMeta is a struct which stores task's meta info. +type TaskMeta struct { + ID int + Epoch int +} + // Task is the basic unit of data instances assigned to trainers. type Task struct { - ID int - Epoch int + Meta TaskMeta Chunks []Chunk } @@ -74,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { var cur taskEntry for i, c := range chunks { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { - cur.Task.ID = id + cur.Task.Meta.ID = id id++ result = append(result, cur) cur.Task.Chunks = nil @@ -84,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { } if len(cur.Task.Chunks) > 0 { - cur.Task.ID = id + cur.Task.Meta.ID = id result = append(result, cur) } @@ -258,8 +263,8 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { return nil } -func (s *Service) procFailedTask(t taskEntry, epoch int) { - if t.Task.Epoch != epoch { +func (s *Service) processFailedTask(t taskEntry, epoch int) { + if t.Task.Meta.Epoch != epoch { // new epoch, task launched after the // schedule of this timeout check or failed status report. return @@ -272,7 +277,7 @@ func (s *Service) procFailedTask(t taskEntry, epoch int) { } }() - delete(s.taskQueues.Pending, t.Task.ID) + delete(s.taskQueues.Pending, t.Task.Meta.ID) t.NumFailure++ if t.NumFailure > s.failureMax { @@ -296,7 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { return } - s.procFailedTask(t, epoch) + s.processFailedTask(t, epoch) } } @@ -345,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error { } t := s.taskQueues.Todo[0] - t.Task.Epoch++ + t.Task.Meta.Epoch++ s.taskQueues.Todo = s.taskQueues.Todo[1:] - s.taskQueues.Pending[t.Task.ID] = t + s.taskQueues.Pending[t.Task.Meta.ID] = t err := s.snapshot() if err != nil { return err } *task = t.Task - log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t) + log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Meta) - time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Task.Epoch)) + time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch)) return nil } @@ -373,7 +378,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { if !ok { err := errors.New("pending task not found") log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) - return err + return nil } // task finished, reset timeout @@ -396,14 +401,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { return err } -// TaskID is a struct which client uses for reports failure. -type TaskID struct { - ID int - Epoch int -} - // TaskFailed tells the service that a task is failed. -func (s *Service) TaskFailed(taskID TaskID, dummy *int) error { +func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { select { case <-s.ready: } @@ -411,13 +410,13 @@ func (s *Service) TaskFailed(taskID TaskID, dummy *int) error { s.mu.Lock() defer s.mu.Unlock() - t, ok := s.taskQueues.Pending[taskID.ID] + t, ok := s.taskQueues.Pending[meta.ID] if !ok { err := errors.New("pending task not found") - log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", taskID) - return err + log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Meta) + return nil } - s.procFailedTask(t, taskID.Epoch) + s.processFailedTask(t, meta.Epoch) return nil } diff --git a/go/master/service_internal_test.go b/go/master/service_internal_test.go index bc435b505..9c0d1d0a3 100644 --- a/go/master/service_internal_test.go +++ b/go/master/service_internal_test.go @@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) { cs := make([]Chunk, 100) ts := partition(cs, 20) for i := range ts { - if ts[i].Task.ID != i { + if ts[i].Task.Meta.ID != i { t.Error(ts[i], i) } } -- GitLab