提交 a40a7a5c 编写于 作者: G gongweibao

fix by helin's comments

上级 578dd090
......@@ -62,7 +62,7 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly
// correct, but a reasonable approximation.
c.taskFinished(t.ID)
c.taskFinished(t.Meta.ID)
}
}
......@@ -113,8 +113,8 @@ func (c *Client) taskFinished(taskID int) error {
}
// taskFailed tells the master server that a task has failed.
func (c *Client) taskFailed(taskID TaskID) error {
return c.conn.Call("Service.TaskFinished", taskID, nil)
func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFinished", meta, nil)
}
// NextRecord returns next record in the dataset.
......
......@@ -95,7 +95,7 @@ func TestGetFinishTask(t *testing.T) {
t.Fatalf("Should get error, pass: %d\n", i)
}
err = c.taskFinished(tasks[0].ID)
err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
......@@ -107,7 +107,7 @@ func TestGetFinishTask(t *testing.T) {
tasks = append(tasks, task)
for _, task := range tasks {
err = c.taskFinished(task.ID)
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
......
......@@ -31,10 +31,15 @@ type Chunk struct {
Index recordio.Index // chunk index
}
// TaskMeta is a struct which stores task's meta info.
//
// It is embedded in Task as the Meta field and is also the argument a
// client passes to Service.TaskFailed when reporting a failed task.
type TaskMeta struct {
// ID identifies the task. It is assigned sequentially by partition and
// is the key under which a dispatched task is stored in the pending queue.
ID int
// Epoch is incremented every time GetTask dispatches the task. It is
// compared against the epoch carried by a timeout check or a failure
// report so that reports belonging to an earlier dispatch are ignored.
Epoch int
}
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
ID int
Epoch int
Meta TaskMeta
Chunks []Chunk
}
......@@ -74,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
var cur taskEntry
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
cur.Task.Meta.ID = id
id++
result = append(result, cur)
cur.Task.Chunks = nil
......@@ -84,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
}
if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id
cur.Task.Meta.ID = id
result = append(result, cur)
}
......@@ -258,8 +263,8 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil
}
func (s *Service) procFailedTask(t taskEntry, epoch int) {
if t.Task.Epoch != epoch {
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check or failed status report.
return
......@@ -272,7 +277,7 @@ func (s *Service) procFailedTask(t taskEntry, epoch int) {
}
}()
delete(s.taskQueues.Pending, t.Task.ID)
delete(s.taskQueues.Pending, t.Task.Meta.ID)
t.NumFailure++
if t.NumFailure > s.failureMax {
......@@ -296,7 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return
}
s.procFailedTask(t, epoch)
s.processFailedTask(t, epoch)
}
}
......@@ -345,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
t := s.taskQueues.Todo[0]
t.Task.Epoch++
t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t
s.taskQueues.Pending[t.Task.Meta.ID] = t
err := s.snapshot()
if err != nil {
return err
}
*task = t.Task
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t)
log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Meta)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Task.Epoch))
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil
}
......@@ -373,7 +378,7 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
if !ok {
err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err
return nil
}
// task finished, reset timeout
......@@ -396,14 +401,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
return err
}
// TaskID is a struct which client uses for reports failure.
type TaskID struct {
ID int
Epoch int
}
// TaskFailed tells the service that a task has failed.
func (s *Service) TaskFailed(taskID TaskID, dummy *int) error {
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}
......@@ -411,13 +410,13 @@ func (s *Service) TaskFailed(taskID TaskID, dummy *int) error {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID.ID]
t, ok := s.taskQueues.Pending[meta.ID]
if !ok {
err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", taskID)
return err
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Meta)
return nil
}
s.procFailedTask(t, taskID.Epoch)
s.processFailedTask(t, meta.Epoch)
return nil
}
......@@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
if ts[i].Task.ID != i {
if ts[i].Task.Meta.ID != i {
t.Error(ts[i], i)
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册