From 54e8263cae3ffcc597d977330f78fb5020c1dba2 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Fri, 9 Jun 2017 01:43:38 +0000 Subject: [PATCH] implement master server client, remove unnecessary dummy variable --- go/cmd/master/master.go | 50 +-------------- go/master/client.go | 14 +++-- go/master/client_test.go | 49 +++++++++++---- go/master/service.go | 121 ++++++++++++++++++++++++++++++++----- go/pserver/client.go | 12 ++-- go/pserver/service_test.go | 31 ++++------ 6 files changed, 174 insertions(+), 103 deletions(-) diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go index 65548b7b68..25cd1cafcd 100644 --- a/go/cmd/master/master.go +++ b/go/cmd/master/master.go @@ -1,78 +1,32 @@ package main import ( - "fmt" "net" "net/http" "net/rpc" - "os" - "path/filepath" "strconv" - "strings" "time" "github.com/namsral/flag" "github.com/PaddlePaddle/Paddle/go/master" - "github.com/PaddlePaddle/recordio" ) func main() { port := flag.Int("port", 8080, "port of the master server.") - dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.") + faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") flag.Parse() - if *dataset == "" { - panic("no dataset specified.") - } - if *faultTolerance { panic("fault tolernance not implemented.") - } - - var chunks []master.Chunk - var paths []string - ss := strings.Split(*dataset, ",") - fmt.Println(ss) - for _, s := range ss { - match, err := filepath.Glob(s) - if err != nil { - panic(err) - } - paths = append(paths, match...) - } - - if len(paths) == 0 { - panic("no valid datset specified.") - } - - for _, path := range paths { - f, err := os.Open(path) - if err != nil { - panic(err) - } - - index, err := recordio.LoadIndex(f) - if err != nil { - panic(err) - } - f.Close() - count := index.NumChunks() - for i := 0; i < count; i++ { - chunk := master.Chunk{ - Path: path, - Index: *index.ChunkIndex(i), - } - chunks = append(chunks, chunk) - } } - s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) + s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) err := rpc.Register(s) if err != nil { panic(err) diff --git a/go/master/client.go b/go/master/client.go index 23ef18f9e2..791db5a975 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -59,16 +59,22 @@ func (c *Client) monitorMaster(addr Addresser) { } } +// SetDataset set dataset for the master server to dispatch. +// +// SetDataset can be call multiple times from different nodes. But +// only the first call will be honored. +func (c *Client) SetDataset(globPaths []string) error { + return c.conn.Call("Service.SetDataset", globPaths, nil) +} + // GetTask gets a new task from the master server. func (c *Client) GetTask() (Task, error) { - var dummy int var t Task - err := c.conn.Call("Service.GetTask", dummy, &t) + err := c.conn.Call("Service.GetTask", 0, &t) return t, err } // TaskFinished tells the master server a task is finished. func (c *Client) TaskFinished(taskID int) error { - var dummy int - return c.conn.Call("Service.TaskFinished", taskID, &dummy) + return c.conn.Call("Service.TaskFinished", taskID, nil) } diff --git a/go/master/client_test.go b/go/master/client_test.go index 4603bdc4d6..5abad0d820 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -5,12 +5,14 @@ import ( "net" "net/http" "net/rpc" + "os" "strconv" "strings" "testing" "time" "github.com/PaddlePaddle/Paddle/go/master" + "github.com/PaddlePaddle/recordio" ) const ( @@ -34,8 +36,7 @@ func init() { port = p go func(l net.Listener) { - chunks := make([]master.Chunk, totalTask) - s := master.NewService(chunks, chunkPerTask, time.Second, 1) + s := master.NewService(chunkPerTask, time.Second, 1) server := rpc.NewServer() err := server.Register(s) if err != nil { @@ -58,21 +59,47 @@ func (a addresser) Address() string { } func TestClientFull(t *testing.T) { + const p = "/tmp/master_client_test_0" + f, err := os.Create(p) + if err != nil { + panic(err) + } + + for i := 0; i < totalTask*chunkPerTask; i++ { + w := recordio.NewWriter(f, -1, -1) + w.Write(nil) + // call Close to force RecordIO writing a chunk. + w.Close() + } + f.Close() + c := master.NewClient(addresser(fmt.Sprintf(":%d", port))) + c.SetDataset([]string{p}) - for i := 0; i < 5*totalTask/chunkPerTask; i++ { - task, err := c.GetTask() - if err != nil { - panic(err) + checkOnePass := func(i int) { + var tasks []master.Task + for i := 0; i < totalTask; i++ { + task, err := c.GetTask() + if err != nil { + t.Fatal(i, err) + } + tasks = append(tasks, task) } - if len(task.Chunks) != chunkPerTask { - t.Fatal("wrong number of chunk per task", len(task.Chunks)) + _, err = c.GetTask() + if err == nil { + t.Fatal(i, "should get error.") } - err = c.TaskFinished(task.ID) - if err != nil { - panic(err) + for _, task := range tasks { + err = c.TaskFinished(task.ID) + if err != nil { + t.Fatal(i, err) + } } } + + for i := 0; i < 10; i++ { + checkOnePass(i) + } } diff --git a/go/master/service.go b/go/master/service.go index 8d6bbecc49..c80037a3b3 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -3,6 +3,8 @@ package master import ( "errors" "log" + "os" + "path/filepath" "sync" "time" @@ -13,18 +15,15 @@ const ( targetTaskCount = 300 ) -// errors -var ( - ErrNoMoreTask = errors.New("no more task for current pass") - ErrPendingTaskNotFound = errors.New("pending task not found") -) - // Service is the master server service. type Service struct { - timeoutDur time.Duration - timeoutMax int + chunksPerTask int + timeoutDur time.Duration + timeoutMax int + ready chan struct{} mu sync.Mutex + initBegan bool taskQueues taskQueues } @@ -63,13 +62,14 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { } // NewService creates a new service. -func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { +func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { s := &Service{} + s.chunksPerTask = chunksPerTask s.timeoutDur = timeoutDur s.timeoutMax = timeoutMax s.taskQueues = taskQueues{} s.taskQueues.Pending = make(map[int]taskEntry) - s.taskQueues.Todo = partition(chunks, chunksPerTask) + s.ready = make(chan struct{}) return s } @@ -104,13 +104,102 @@ func (s *Service) snapshot() error { return nil } +// SetDataset sets dataset to dispatch for the master server. +// +// SetDataset can be call multiple times. But only the first call will +// be honored. +func (s *Service) SetDataset(globPaths []string, dummy *int) error { + if len(globPaths) == 0 { + return errors.New("no dataset specified") + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.initBegan { + // SetDataset already called. All trainer will call + // SetDataset, but we only handle the first one. Treat + // other calls as successful but do nothing. + return nil + } + + s.initBegan = true + + var chunks []Chunk + var paths []string + + for _, s := range globPaths { + match, err := filepath.Glob(s) + if err != nil { + panic(err) + } + paths = append(paths, match...) + } + + if len(paths) == 0 { + return errors.New("no valid datset specified") + } + + for _, path := range paths { + f, err := os.Open(path) + if err != nil { + panic(err) + } + + index, err := recordio.LoadIndex(f) + if err != nil { + return err + } + err = f.Close() + if err != nil { + return err + } + + count := index.NumChunks() + for i := 0; i < count; i++ { + chunk := Chunk{ + Path: path, + Index: *index.ChunkIndex(i), + } + chunks = append(chunks, chunk) + } + } + + s.taskQueues.Todo = partition(chunks, s.chunksPerTask) + + err := s.snapshot() + if err != nil { + return err + } + + close(s.ready) + return nil +} + // GetTask gets a new task from the service. func (s *Service) GetTask(dummy int, task *Task) error { + select { + case <-s.ready: + } + s.mu.Lock() defer s.mu.Unlock() if len(s.taskQueues.Todo) == 0 { - return ErrNoMoreTask + if len(s.taskQueues.Done) == 0 { + if len(s.taskQueues.Pending) == 0 { + return errors.New("all task failed") + } + + // TODO(helin): client need to retry in this + // error case. Gotcha: RPC client can't + // compare returned error with predefined + // erros like io.EOF. Because interface don't + // have same dynamic value when in different + // process. + return errors.New("no more available task") + } + s.taskQueues.Todo = s.taskQueues.Done + s.taskQueues.Todo = nil } t := s.taskQueues.Todo[0] @@ -163,12 +252,16 @@ func (s *Service) GetTask(dummy int, task *Task) error { // TaskFinished tell the service that a task is finished. func (s *Service) TaskFinished(taskID int, dummy *int) error { + select { + case <-s.ready: + } + s.mu.Lock() defer s.mu.Unlock() t, ok := s.taskQueues.Pending[taskID] if !ok { - return ErrPendingTaskNotFound + return errors.New("pending task not found") } // task finished, reset timeout @@ -176,8 +269,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { s.taskQueues.Done = append(s.taskQueues.Done, t) delete(s.taskQueues.Pending, taskID) - if len(s.taskQueues.Todo) == 0 { - s.taskQueues.Todo = s.taskQueues.Done + if len(s.taskQueues.Pending) == 0 { + s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Done = nil } diff --git a/go/pserver/client.go b/go/pserver/client.go index 7930f012c3..bbe93cbb6b 100644 --- a/go/pserver/client.go +++ b/go/pserver/client.go @@ -102,16 +102,14 @@ func (c *Client) BeginInitParams() bool { // InitParam initializes the parameter on parameter servers. func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { - var dummy int - return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy) + return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil) } // FinishInitParams tells parameter servers client has sent all // parameters to parameter servers as initialization. func (c *Client) FinishInitParams() error { for _, p := range c.pservers { - var dummy int - err := p.Call("Service.FinishInitParams", dummy, &dummy) + err := p.Call("Service.FinishInitParams", 0, nil) if err != nil { return err } @@ -125,8 +123,7 @@ func (c *Client) SendGrads(grads []Gradient) error { errCh := make(chan error, len(grads)) for _, g := range grads { go func(g Gradient) { - var dummy int - err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy) + err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil) errCh <- err }(g) } @@ -205,8 +202,7 @@ func (c *Client) Save(path string) error { errCh := make(chan error, len(c.pservers)) for _, p := range c.pservers { - var dummy int - err := p.Call("Service.Save", path, &dummy) + err := p.Call("Service.Save", path, nil) errCh <- err } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 796492ffb4..c40cecd0b6 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -15,8 +15,7 @@ func TestFull(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - var dummy int - err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy) + err := s.InitParam(pserver.ParameterWithConfig{p, nil}, nil) if err != nil { t.FailNow() } @@ -25,12 +24,12 @@ func TestFull(t *testing.T) { p1.Name = "param_b" p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} p1.ElementType = pserver.Float32 - err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy) + err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, nil) if err != nil { t.FailNow() } - err = s.FinishInitParams(0, &dummy) + err = s.FinishInitParams(0, nil) if err != nil { t.FailNow() } @@ -46,11 +45,11 @@ func TestFull(t *testing.T) { } g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) - err = s.SendGrad(g1, &dummy) + err = s.SendGrad(g1, nil) if err != nil { t.FailNow() } - err = s.SendGrad(g2, &dummy) + err = s.SendGrad(g2, nil) if err != nil { t.FailNow() @@ -74,23 +73,21 @@ func TestFull(t *testing.T) { func TestMultipleInit(t *testing.T) { s := pserver.NewService() - var dummy int - err := s.FinishInitParams(0, &dummy) + err := s.FinishInitParams(0, nil) if err != nil { t.FailNow() } - err = s.FinishInitParams(0, &dummy) - if err.Error() != pserver.AlreadyInitialized { + err = s.FinishInitParams(0, nil) + if err != pserver.ErrAlreadyInitialized { t.FailNow() } } func TestUninitialized(t *testing.T) { s := pserver.NewService() - var dummy int - err := s.SendGrad(pserver.Gradient{}, &dummy) - if err.Error() != pserver.Uninitialized { + err := s.SendGrad(pserver.Gradient{}, nil) + if err != pserver.ErrUninitialized { t.FailNow() } } @@ -112,8 +109,7 @@ func TestBlockUntilInitialized(t *testing.T) { wg.Add(1) go func() { - var dummy int - err := s.Save("", &dummy) + err := s.Save("", nil) if err != nil { t.FailNow() } @@ -134,13 +130,12 @@ func TestBlockUntilInitialized(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - var dummy int - err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy) + err := s.InitParam(pserver.ParameterWithConfig{p, nil}, nil) if err != nil { t.FailNow() } - err = s.FinishInitParams(0, &dummy) + err = s.FinishInitParams(0, nil) if err != nil { t.FailNow() } -- GitLab