提交 a40a7a5c 编写于 作者: G gongweibao

fix by helin's comments

上级 578dd090
...@@ -62,7 +62,7 @@ func (c *Client) getRecords() { ...@@ -62,7 +62,7 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data // We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly // instance of the task is read. This is not exactly
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
c.taskFinished(t.ID) c.taskFinished(t.Meta.ID)
} }
} }
...@@ -113,8 +113,8 @@ func (c *Client) taskFinished(taskID int) error { ...@@ -113,8 +113,8 @@ func (c *Client) taskFinished(taskID int) error {
} }
// TaskFailed tell the master server as task is failed. // TaskFailed tell the master server as task is failed.
func (c *Client) taskFailed(taskID TaskID) error { func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFinished", taskID, nil) return c.conn.Call("Service.TaskFinished", meta, nil)
} }
// NextRecord returns next record in the dataset. // NextRecord returns next record in the dataset.
......
...@@ -95,7 +95,7 @@ func TestGetFinishTask(t *testing.T) { ...@@ -95,7 +95,7 @@ func TestGetFinishTask(t *testing.T) {
t.Fatalf("Should get error, pass: %d\n", i) t.Fatalf("Should get error, pass: %d\n", i)
} }
err = c.taskFinished(tasks[0].ID) err = c.taskFinished(tasks[0].Meta.ID)
if err != nil { if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", err, i)
} }
...@@ -107,7 +107,7 @@ func TestGetFinishTask(t *testing.T) { ...@@ -107,7 +107,7 @@ func TestGetFinishTask(t *testing.T) {
tasks = append(tasks, task) tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.taskFinished(task.ID) err = c.taskFinished(task.Meta.ID)
if err != nil { if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", err, i)
} }
......
...@@ -31,10 +31,15 @@ type Chunk struct { ...@@ -31,10 +31,15 @@ type Chunk struct {
Index recordio.Index // chunk index Index recordio.Index // chunk index
} }
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
ID int
Epoch int
}
// Task is the basic unit of data instances assigned to trainers. // Task is the basic unit of data instances assigned to trainers.
type Task struct { type Task struct {
ID int Meta TaskMeta
Epoch int
Chunks []Chunk Chunks []Chunk
} }
...@@ -74,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -74,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
var cur taskEntry var cur taskEntry
for i, c := range chunks { for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
id++ id++
result = append(result, cur) result = append(result, cur)
cur.Task.Chunks = nil cur.Task.Chunks = nil
...@@ -84,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -84,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
if len(cur.Task.Chunks) > 0 { if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
result = append(result, cur) result = append(result, cur)
} }
...@@ -258,8 +263,8 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -258,8 +263,8 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil return nil
} }
func (s *Service) procFailedTask(t taskEntry, epoch int) { func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Epoch != epoch { if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the // new epoch, task launched after the
// schedule of this timeout check or failed status report. // schedule of this timeout check or failed status report.
return return
...@@ -272,7 +277,7 @@ func (s *Service) procFailedTask(t taskEntry, epoch int) { ...@@ -272,7 +277,7 @@ func (s *Service) procFailedTask(t taskEntry, epoch int) {
} }
}() }()
delete(s.taskQueues.Pending, t.Task.ID) delete(s.taskQueues.Pending, t.Task.Meta.ID)
t.NumFailure++ t.NumFailure++
if t.NumFailure > s.failureMax { if t.NumFailure > s.failureMax {
...@@ -296,7 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -296,7 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return return
} }
s.procFailedTask(t, epoch) s.processFailedTask(t, epoch)
} }
} }
...@@ -345,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -345,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error {
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
t.Task.Epoch++ t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:] s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t s.taskQueues.Pending[t.Task.Meta.ID] = t
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
return err return err
} }
*task = t.Task *task = t.Task
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t) log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Meta)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Task.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil return nil
} }
...@@ -373,7 +378,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -373,7 +378,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
if !ok { if !ok {
err := errors.New("pending task not found") err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err return nil
} }
// task finished, reset timeout // task finished, reset timeout
...@@ -396,14 +401,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -396,14 +401,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
return err return err
} }
// TaskID is a struct which client uses for reports failure.
type TaskID struct {
ID int
Epoch int
}
// TaskFailed tells the service that a task is failed. // TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(taskID TaskID, dummy *int) error { func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select { select {
case <-s.ready: case <-s.ready:
} }
...@@ -411,13 +410,13 @@ func (s *Service) TaskFailed(taskID TaskID, dummy *int) error { ...@@ -411,13 +410,13 @@ func (s *Service) TaskFailed(taskID TaskID, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID.ID] t, ok := s.taskQueues.Pending[meta.ID]
if !ok { if !ok {
err := errors.New("pending task not found") err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", taskID) log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Meta)
return err return nil
} }
s.procFailedTask(t, taskID.Epoch) s.processFailedTask(t, meta.Epoch)
return nil return nil
} }
...@@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) { ...@@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 20)
for i := range ts { for i := range ts {
if ts[i].Task.ID != i { if ts[i].Task.Meta.ID != i {
t.Error(ts[i], i) t.Error(ts[i], i)
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册