From 0bebaa05beda6aca1d9cbedc3fb87c9978cd7df6 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Tue, 13 Jun 2017 21:41:57 +0000 Subject: [PATCH] fix according to comments --- go/master/client.go | 2 + go/master/client_test.go | 15 ++++++ go/master/service.go | 110 +++++++++++++++++++++++---------------- go/pserver/client.go | 33 ++++++------ 4 files changed, 101 insertions(+), 59 deletions(-) diff --git a/go/master/client.go b/go/master/client.go index 791db5a9753..20c66340dc2 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -28,6 +28,8 @@ func NewClient(addr Addresser) *Client { func (c *Client) monitorMaster(addr Addresser) { lastMaster := "" monitor := func() { + // get the latest address of the master server, + // connect to the new address once address changed. curMaster := addr.Address() if curMaster != lastMaster { if curMaster == "" { diff --git a/go/master/client_test.go b/go/master/client_test.go index 5abad0d8208..df708ad7912 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + log "github.com/sirupsen/logrus" + "github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/recordio" ) @@ -23,6 +25,8 @@ const ( var port int func init() { + log.SetLevel(log.ErrorLevel) + l, err := net.Listen("tcp", ":0") if err != nil { panic(err) @@ -91,6 +95,17 @@ func TestClientFull(t *testing.T) { t.Fatal(i, "should get error.") } + err = c.TaskFinished(tasks[0].ID) + if err != nil { + t.Fatal(err) + } + tasks = tasks[1:] + task, err := c.GetTask() + if err != nil { + t.Fatal(err) + } + tasks = append(tasks, task) + for _, task := range tasks { err = c.TaskFinished(task.ID) if err != nil { diff --git a/go/master/service.go b/go/master/service.go index 6d6a4e30ab1..c2ab3cc6d82 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -2,12 +2,13 @@ package master import ( "errors" - "log" "os" "path/filepath" "sync" "time" + log "github.com/sirupsen/logrus" + 
"github.com/PaddlePaddle/recordio" ) @@ -112,7 +113,7 @@ func readChunks(globPaths []string) ([]Chunk, error) { } if len(paths) == 0 { - return nil, errors.New("no valid datset specified") + return nil, errors.New("no valid dataset specified") } for _, path := range paths { @@ -170,6 +171,7 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { err = s.snapshot() if err != nil { + log.Errorln(err) return err } @@ -178,6 +180,43 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { return nil } +func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { + return func() { + s.mu.Lock() + defer s.mu.Unlock() + + t, ok := s.taskQueues.Pending[taskID] + if !ok { + return + } + + if t.Epoch != epoch { + // new epoch, task launched after the + // schedule of this timeout check. + return + } + + defer func() { + err := s.snapshot() + if err != nil { + log.Errorln(err) + } + }() + + delete(s.taskQueues.Pending, t.Task.ID) + + t.NumTimeout++ + if t.NumTimeout > s.timeoutMax { + log.Warningf("Task %v failed %d times, discard.\n", t.Task, t.NumTimeout) + s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) + return + } + + log.Warningf("Task %v failed %d times, retry.\n", t.Task, t.NumTimeout) + s.taskQueues.Todo = append(s.taskQueues.Todo, t) + } +} + // GetTask gets a new task from the service. func (s *Service) GetTask(dummy int, task *Task) error { select { @@ -190,19 +229,25 @@ func (s *Service) GetTask(dummy int, task *Task) error { if len(s.taskQueues.Todo) == 0 { if len(s.taskQueues.Done) == 0 { if len(s.taskQueues.Pending) == 0 { - return errors.New("all task failed") + err := errors.New("all task failed") + log.Warningln(err) + return err } // TODO(helin): client need to retry in this // error case. Gotcha: RPC client can't // compare returned error with predefined - // errors like io.EOF. Because interface don't + // errors like io.EOF, because interface don't // have same dynamic value when in different - // process. 
- return errors.New("no more available task") + // process. So we need to figure out a way for + // client to check this error correctly. + err := errors.New("no more available task") + log.Warningln(err) + return err } s.taskQueues.Todo = s.taskQueues.Done - s.taskQueues.Todo = nil + s.taskQueues.Done = nil + log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.") } t := s.taskQueues.Todo[0] @@ -215,41 +260,9 @@ func (s *Service) GetTask(dummy int, task *Task) error { } *task = t.Task + log.Infof("Task #%d dispatched\n", task.ID) - time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() { - return func() { - s.mu.Lock() - defer s.mu.Unlock() - - t, ok := s.taskQueues.Pending[taskID] - if !ok { - return - } - - if t.Epoch != epoch { - // new epoch, task launched after the - // schedule of this timeout check. - return - } - - defer func() { - err := s.snapshot() - if err != nil { - log.Println(err) - } - }() - - delete(s.taskQueues.Pending, t.Task.ID) - - t.NumTimeout++ - if t.NumTimeout > s.timeoutMax { - s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task) - return - } - - s.taskQueues.Todo = append(s.taskQueues.Todo, t) - } - }(t.Task.ID, t.Epoch)) + time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch)) return nil } @@ -262,9 +275,13 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { s.mu.Lock() defer s.mu.Unlock() + log.Infof("Task %d finished\n", taskID) + t, ok := s.taskQueues.Pending[taskID] if !ok { - return errors.New("pending task not found") + err := errors.New("pending task not found") + log.Warningln(err) + return err } // task finished, reset timeout @@ -272,10 +289,15 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { s.taskQueues.Done = append(s.taskQueues.Done, t) delete(s.taskQueues.Pending, taskID) - if len(s.taskQueues.Pending) == 0 { + if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { + log.Infoln("No more todo and pending task, 
start a new pass.") s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Done = nil } - return s.snapshot() + err := s.snapshot() + if err != nil { + log.Errorln(err) + } + return err } diff --git a/go/pserver/client.go b/go/pserver/client.go index d8c65b2e137..afe1eecd015 100644 --- a/go/pserver/client.go +++ b/go/pserver/client.go @@ -57,26 +57,29 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { } for i := range lastServers { - if lastServers[i].Addr != curServers[i].Addr { - if curServers[i].Addr == "" { - err := c.pservers[i].Close() - if err != nil { - log.Println(err) - } - - continue - } + if lastServers[i].Addr == curServers[i].Addr { + continue + } - err := c.pservers[i].Connect(curServers[i].Addr) + if curServers[i].Addr == "" { + err := c.pservers[i].Close() if err != nil { log.Println(err) - - // connect to addr failed, set - // to last known addr in order - // to retry next time. - curServers[i].Addr = lastServers[i].Addr } + + continue } + + err := c.pservers[i].Connect(curServers[i].Addr) + if err != nil { + log.Println(err) + + // connect to addr failed, set + // to last known addr in order + // to retry next time. + curServers[i].Addr = lastServers[i].Addr + } + } lastServers = curServers -- GitLab