提交 0bebaa05 编写于 作者: H Helin Wang

fix according to comments

上级 f6148eb2
......@@ -28,6 +28,8 @@ func NewClient(addr Addresser) *Client {
func (c *Client) monitorMaster(addr Addresser) {
lastMaster := ""
monitor := func() {
// get the latest address of the master server, and
// connect to the new address once the address has changed.
curMaster := addr.Address()
if curMaster != lastMaster {
if curMaster == "" {
......
......@@ -11,6 +11,8 @@ import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
......@@ -23,6 +25,8 @@ const (
var port int
func init() {
log.SetLevel(log.ErrorLevel)
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
......@@ -91,6 +95,17 @@ func TestClientFull(t *testing.T) {
t.Fatal(i, "should get error.")
}
err = c.TaskFinished(tasks[0].ID)
if err != nil {
t.Fatal(err)
}
tasks = tasks[1:]
task, err := c.GetTask()
if err != nil {
t.Fatal(err)
}
tasks = append(tasks, task)
for _, task := range tasks {
err = c.TaskFinished(task.ID)
if err != nil {
......
......@@ -2,12 +2,13 @@ package master
import (
"errors"
"log"
"os"
"path/filepath"
"sync"
"time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/recordio"
)
......@@ -112,7 +113,7 @@ func readChunks(globPaths []string) ([]Chunk, error) {
}
if len(paths) == 0 {
return nil, errors.New("no valid datset specified")
return nil, errors.New("no valid dataset specified")
}
for _, path := range paths {
......@@ -170,6 +171,7 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
err = s.snapshot()
if err != nil {
log.Errorln(err)
return err
}
......@@ -178,6 +180,43 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
return nil
}
// checkTimeoutFunc builds the callback that performs the timeout check for
// the pending task identified by taskID, as it was dispatched in the given
// epoch. The returned function is meant to be run later (for example by a
// timer); when it runs it holds s.mu for its whole duration.
//
// If the task is no longer pending, or is pending with a different epoch,
// the callback does nothing. Otherwise the task is removed from the pending
// set and its timeout counter is incremented: once the counter exceeds
// s.timeoutMax the task is recorded as failed, otherwise it is appended to
// the todo queue again. Whenever the queues were modified, a snapshot is
// taken before the lock is released, and a snapshot error is only logged.
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
	return func() {
		s.mu.Lock()
		defer s.mu.Unlock()

		entry, pending := s.taskQueues.Pending[taskID]
		// Nothing to do when the task already left the pending set, or
		// when it was dispatched again (new epoch) after this timeout
		// check had been scheduled.
		if !pending || entry.Epoch != epoch {
			return
		}

		// Registered after the lock's deferred Unlock, so the snapshot
		// runs first and still happens under the lock.
		defer func() {
			if err := s.snapshot(); err != nil {
				log.Errorln(err)
			}
		}()

		delete(s.taskQueues.Pending, entry.Task.ID)
		entry.NumTimeout++

		if entry.NumTimeout <= s.timeoutMax {
			// Still within the allowed number of timeouts: queue it again.
			log.Warningf("Task %v failed %d times, retry.\n", entry.Task, entry.NumTimeout)
			s.taskQueues.Todo = append(s.taskQueues.Todo, entry)
			return
		}

		// Too many timeouts: give up on this task.
		log.Warningf("Task %v failed %d times, discard.\n", entry.Task, entry.NumTimeout)
		s.taskQueues.Failed = append(s.taskQueues.Failed, entry.Task)
	}
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
select {
......@@ -190,19 +229,25 @@ func (s *Service) GetTask(dummy int, task *Task) error {
if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 {
return errors.New("all task failed")
err := errors.New("all task failed")
log.Warningln(err)
return err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF. Because interface don't
// errors like io.EOF, because interface don't
// have same dynamic value when in different
// process.
return errors.New("no more available task")
// process. So we need to figure out a way for
// client to check this error correctly.
err := errors.New("no more available task")
log.Warningln(err)
return err
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Todo = nil
s.taskQueues.Done = nil
log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
}
t := s.taskQueues.Todo[0]
......@@ -215,41 +260,9 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
*task = t.Task
log.Infof("Task #%d dispatched\n", task.ID)
time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
return func() {
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return
}
if t.Epoch != epoch {
// new epoch, task launched after the
// schedule of this timeout check.
return
}
defer func() {
err := s.snapshot()
if err != nil {
log.Println(err)
}
}()
delete(s.taskQueues.Pending, t.Task.ID)
t.NumTimeout++
if t.NumTimeout > s.timeoutMax {
s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
return
}
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
}
}(t.Task.ID, t.Epoch))
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
return nil
}
......@@ -262,9 +275,13 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock()
defer s.mu.Unlock()
log.Infof("Task %d finished\n", taskID)
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return errors.New("pending task not found")
err := errors.New("pending task not found")
log.Warningln(err)
return err
}
// task finished, reset timeout
......@@ -272,10 +289,15 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID)
if len(s.taskQueues.Pending) == 0 {
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil
}
return s.snapshot()
err := s.snapshot()
if err != nil {
log.Errorln(err)
}
return err
}
......@@ -57,7 +57,10 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
}
for i := range lastServers {
if lastServers[i].Addr != curServers[i].Addr {
if lastServers[i].Addr == curServers[i].Addr {
continue
}
if curServers[i].Addr == "" {
err := c.pservers[i].Close()
if err != nil {
......@@ -76,7 +79,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// to retry next time.
curServers[i].Addr = lastServers[i].Addr
}
}
}
lastServers = curServers
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册