提交 0bebaa05 编写于 作者: H Helin Wang

fix according to comments

上级 f6148eb2
...@@ -28,6 +28,8 @@ func NewClient(addr Addresser) *Client { ...@@ -28,6 +28,8 @@ func NewClient(addr Addresser) *Client {
func (c *Client) monitorMaster(addr Addresser) { func (c *Client) monitorMaster(addr Addresser) {
lastMaster := "" lastMaster := ""
monitor := func() { monitor := func() {
// get the lastest address of the master server,
// connect to the new address once address changed.
curMaster := addr.Address() curMaster := addr.Address()
if curMaster != lastMaster { if curMaster != lastMaster {
if curMaster == "" { if curMaster == "" {
......
...@@ -11,6 +11,8 @@ import ( ...@@ -11,6 +11,8 @@ import (
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
...@@ -23,6 +25,8 @@ const ( ...@@ -23,6 +25,8 @@ const (
var port int var port int
func init() { func init() {
log.SetLevel(log.ErrorLevel)
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
panic(err) panic(err)
...@@ -91,6 +95,17 @@ func TestClientFull(t *testing.T) { ...@@ -91,6 +95,17 @@ func TestClientFull(t *testing.T) {
t.Fatal(i, "should get error.") t.Fatal(i, "should get error.")
} }
err = c.TaskFinished(tasks[0].ID)
if err != nil {
t.Fatal(err)
}
tasks = tasks[1:]
task, err := c.GetTask()
if err != nil {
t.Fatal(err)
}
tasks = append(tasks, task)
for _, task := range tasks { for _, task := range tasks {
err = c.TaskFinished(task.ID) err = c.TaskFinished(task.ID)
if err != nil { if err != nil {
......
...@@ -2,12 +2,13 @@ package master ...@@ -2,12 +2,13 @@ package master
import ( import (
"errors" "errors"
"log"
"os" "os"
"path/filepath" "path/filepath"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
...@@ -112,7 +113,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { ...@@ -112,7 +113,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
} }
if len(paths) == 0 { if len(paths) == 0 {
return nil, errors.New("no valid datset specified") return nil, errors.New("no valid dataset specified")
} }
for _, path := range paths { for _, path := range paths {
...@@ -170,6 +171,7 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -170,6 +171,7 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
err = s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
log.Errorln(err)
return err return err
} }
...@@ -178,6 +180,43 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -178,6 +180,43 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil return nil
} }
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}
if t.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
}()
delete(s.taskQueues.Pending, t.Task.ID)
t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
log.Warningf("Task %v failed %d times, discard.\n", t.Task, t.NumTimeout)
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return
}
log.Warningf("Task %v failed %d times, retry.\n", t.Task, t.NumTimeout)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
}
}
// GetTask gets a new task from the service. // GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error { func (s *Service) GetTask(dummy int, task *Task) error {
select { select {
...@@ -190,19 +229,25 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -190,19 +229,25 @@ func (s *Service) GetTask(dummy int, task *Task) error {
if len(s.taskQueues.Todo) == 0 { if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 { if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 { if len(s.taskQueues.Pending) == 0 {
return errors.New("all task failed") err := errors.New("all task failed")
log.Warningln(err)
return err
} }
// TODO(helin): client need to retry in this // TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't // error case. Gotcha: RPC client can't
// compare returned error with predefined // compare returned error with predefined
// errors like io.EOF. Because interface don't // errors like io.EOF, because interface don't
// have same dynamic value when in different // have same dynamic value when in different
// process. // process. So we need to figure out a way for
return errors.New("no more available task") // client to check this error correctly.
err := errors.New("no more available task")
log.Warningln(err)
return err
} }
s.taskQueues.Todo = s.taskQueues.Done s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Todo = nil s.taskQueues.Done = nil
log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
} }
t := s.taskQueues.Todo[0] t := s.taskQueues.Todo[0]
...@@ -215,41 +260,9 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -215,41 +260,9 @@ func (s *Service) GetTask(dummy int, task *Task) error {
} }
*task = t.Task *task = t.Task
log.Infof("Task #%d dispatched\n", task.ID)
time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() { time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
return func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}
if t.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Println(err)
}
}()
delete(s.taskQueues.Pending, t.Task.ID)
t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return
}
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
}
}(t.Task.ID, t.Epoch))
return nil return nil
} }
...@@ -262,9 +275,13 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -262,9 +275,13 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
log.Infof("Task %d finished\n", taskID)
t, ok := s.taskQueues.Pending[taskID] t, ok := s.taskQueues.Pending[taskID]
if !ok { if !ok {
return errors.New("pending task not found") err := errors.New("pending task not found")
log.Warningln(err)
return err
} }
// task finished, reset timeout // task finished, reset timeout
...@@ -272,10 +289,15 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -272,10 +289,15 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.taskQueues.Done = append(s.taskQueues.Done, t) s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
if len(s.taskQueues.Pending) == 0 { if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil s.taskQueues.Done = nil
} }
return s.snapshot() err := s.snapshot()
if err != nil {
log.Errorln(err)
}
return err
} }
...@@ -57,26 +57,29 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -57,26 +57,29 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
} }
for i := range lastServers { for i := range lastServers {
if lastServers[i].Addr != curServers[i].Addr { if lastServers[i].Addr == curServers[i].Addr {
if curServers[i].Addr == "" { continue
err := c.pservers[i].Close() }
if err != nil {
log.Println(err)
}
continue
}
err := c.pservers[i].Connect(curServers[i].Addr) if curServers[i].Addr == "" {
err := c.pservers[i].Close()
if err != nil { if err != nil {
log.Println(err) log.Println(err)
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curServers[i].Addr = lastServers[i].Addr
} }
continue
} }
err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil {
log.Println(err)
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curServers[i].Addr = lastServers[i].Addr
}
} }
lastServers = curServers lastServers = curServers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册