提交 ef67d08c 编写于 作者: G gongweibao 提交者: GitHub

Merge pull request #2719 from gongweibao/taskfail

add TaskFail interface
...@@ -68,7 +68,7 @@ func (c *Client) getRecords() { ...@@ -68,7 +68,7 @@ func (c *Client) getRecords() {
// We treat a task as finished whenever the last data // We treat a task as finished whenever the last data
// instance of the task is read. This is not exactly // instance of the task is read. This is not exactly
// correct, but a reasonable approximation. // correct, but a reasonable approximation.
c.taskFinished(t.ID) c.taskFinished(t.Meta.ID)
} }
} }
...@@ -118,6 +118,11 @@ func (c *Client) taskFinished(taskID int) error { ...@@ -118,6 +118,11 @@ func (c *Client) taskFinished(taskID int) error {
return c.conn.Call("Service.TaskFinished", taskID, nil) return c.conn.Call("Service.TaskFinished", taskID, nil)
} }
// TaskFailed tell the master server as task is failed.
func (c *Client) taskFailed(meta TaskMeta) error {
return c.conn.Call("Service.TaskFailed", meta, nil)
}
// NextRecord returns next record in the dataset. // NextRecord returns next record in the dataset.
// //
// NextRecord will block until the next record is available. It is // NextRecord will block until the next record is available. It is
......
...@@ -95,10 +95,16 @@ func TestGetFinishTask(t *testing.T) { ...@@ -95,10 +95,16 @@ func TestGetFinishTask(t *testing.T) {
t.Fatalf("Should get error, pass: %d\n", i) t.Fatalf("Should get error, pass: %d\n", i)
} }
err = c.taskFinished(tasks[0].ID) err = c.taskFinished(tasks[0].Meta.ID)
if err != nil { if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", err, i)
} }
err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
}
tasks = tasks[1:] tasks = tasks[1:]
task, err := c.getTask() task, err := c.getTask()
if err != nil { if err != nil {
...@@ -107,7 +113,7 @@ func TestGetFinishTask(t *testing.T) { ...@@ -107,7 +113,7 @@ func TestGetFinishTask(t *testing.T) {
tasks = append(tasks, task) tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.taskFinished(task.ID) err = c.taskFinished(task.Meta.ID)
if err != nil { if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i) t.Fatalf("Error: %v, pass: %d\n", err, i)
} }
......
...@@ -31,30 +31,36 @@ type Chunk struct { ...@@ -31,30 +31,36 @@ type Chunk struct {
Index recordio.Index // chunk index Index recordio.Index // chunk index
} }
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
ID int
Epoch int
}
// Task is the basic unit of data instances assigned to trainers. // Task is the basic unit of data instances assigned to trainers.
type Task struct { type Task struct {
ID int Meta TaskMeta
Chunks []Chunk Chunks []Chunk
} }
type taskEntry struct { type taskEntry struct {
Epoch int Task Task
NumTimeout int // A task fails if it's timeout or trainer reports it exits unnormally.
Task Task NumFailure int
} }
type taskQueues struct { type taskQueues struct {
Todo []taskEntry Todo []taskEntry
Pending map[int]taskEntry // map from task ID to task entry Pending map[int]taskEntry // map from task ID to task entry
Done []taskEntry Done []taskEntry
Failed []Task Failed []taskEntry
} }
// Service is the master server service. // Service is the master server service.
type Service struct { type Service struct {
chunksPerTask int chunksPerTask int
timeoutDur time.Duration timeoutDur time.Duration
timeoutMax int failureMax int
ready chan struct{} ready chan struct{}
store Store store Store
...@@ -73,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -73,7 +79,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
var cur taskEntry var cur taskEntry
for i, c := range chunks { for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
id++ id++
result = append(result, cur) result = append(result, cur)
cur.Task.Chunks = nil cur.Task.Chunks = nil
...@@ -83,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -83,7 +89,7 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
if len(cur.Task.Chunks) > 0 { if len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.Meta.ID = id
result = append(result, cur) result = append(result, cur)
} }
...@@ -91,11 +97,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { ...@@ -91,11 +97,11 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
} }
// NewService creates a new service. // NewService creates a new service.
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) { func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
s := &Service{} s := &Service{}
s.chunksPerTask = chunksPerTask s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax s.failureMax = failureMax
s.taskQueues = taskQueues{} s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry) s.taskQueues.Pending = make(map[int]taskEntry)
s.ready = make(chan struct{}) s.ready = make(chan struct{})
...@@ -257,6 +263,34 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -257,6 +263,34 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil return nil
} }
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check or failed status report.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
}()
delete(s.taskQueues.Pending, t.Task.Meta.ID)
t.NumFailure++
if t.NumFailure > s.failureMax {
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Failed = append(s.taskQueues.Failed, t)
return
}
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() { return func() {
s.mu.Lock() s.mu.Lock()
...@@ -267,30 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { ...@@ -267,30 +301,7 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return return
} }
if t.Epoch != epoch { s.processFailedTask(t, epoch)
// new epoch, task launched after the
// schedule of this timeout check.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
}()
delete(s.taskQueues.Pending, t.Task.ID)
t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return
}
log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
} }
} }
...@@ -339,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -339,18 +350,18 @@ func (s *Service) GetTask(dummy int, task *Task) error {
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
t.Epoch++ t.Task.Meta.Epoch++
s.taskQueues.Todo = s.taskQueues.Todo[1:] s.taskQueues.Todo = s.taskQueues.Todo[1:]
s.taskQueues.Pending[t.Task.ID] = t s.taskQueues.Pending[t.Task.Meta.ID] = t
err := s.snapshot() err := s.snapshot()
if err != nil { if err != nil {
return err return err
} }
*task = t.Task *task = t.Task
log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID) log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch)) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
return nil return nil
} }
...@@ -365,13 +376,12 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -365,13 +376,12 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t, ok := s.taskQueues.Pending[taskID] t, ok := s.taskQueues.Pending[taskID]
if !ok { if !ok {
err := errors.New("pending task not found")
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err return nil
} }
// task finished, reset timeout // task finished, reset timeout
t.NumTimeout = 0 t.NumFailure = 0
s.taskQueues.Done = append(s.taskQueues.Done, t) s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
...@@ -389,3 +399,22 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -389,3 +399,22 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
} }
return err return err
} }
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[meta.ID]
if !ok {
log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
return nil
}
s.processFailedTask(t, meta.Epoch)
return nil
}
...@@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) { ...@@ -30,7 +30,7 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 20)
for i := range ts { for i := range ts {
if ts[i].Task.ID != i { if ts[i].Task.Meta.ID != i {
t.Error(ts[i], i) t.Error(ts[i], i)
} }
} }
......
...@@ -42,7 +42,8 @@ func initClient() [numPserver]int { ...@@ -42,7 +42,8 @@ func initClient() [numPserver]int {
ports[i] = p ports[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s, err := pserver.NewService(0) var cp pserver.Checkpoint
s, err := pserver.NewService(0, 1, "", nil, cp)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -174,7 +175,7 @@ func TestNativeClient(t *testing.T) { ...@@ -174,7 +175,7 @@ func TestNativeClient(t *testing.T) {
// TODO: tmperary disable etcdClient test for dependency of etcd) // TODO: tmperary disable etcdClient test for dependency of etcd)
func EtcdClient(t *testing.T) { func EtcdClient(t *testing.T) {
initEtcdClient() initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints) etcdClient := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true)) c2 := client.NewClient(etcdClient, etcdClient.Desired(), selector(true))
ClientTest(t, c2) ClientTest(t, c2)
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册