diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go
index d1f3d7d76c438670faf6677b01e790c5ebe1f2cb..25cd1cafcdf328094a019638f37f908591f5f374 100644
--- a/go/cmd/master/master.go
+++ b/go/cmd/master/master.go
@@ -1,80 +1,32 @@
 package main
 
 import (
-	"fmt"
 	"net"
 	"net/http"
 	"net/rpc"
-	"os"
-	"path/filepath"
 	"strconv"
-	"strings"
 	"time"
 
 	"github.com/namsral/flag"
 
 	"github.com/PaddlePaddle/Paddle/go/master"
-	"github.com/PaddlePaddle/recordio"
 )
 
 func main() {
 	port := flag.Int("port", 8080, "port of the master server.")
-	dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")
+
 	faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
 	taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
 	taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
 	chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
 	flag.Parse()
 
-	if *dataset == "" {
-		panic("no dataset specified.")
-	}
-
 	if *faultTolerance {
 		panic("fault tolernance not implemented.")
 	}
-
-	var chunks []master.Chunk
-	var paths []string
-	ss := strings.Split(*dataset, ",")
-	fmt.Println(ss)
-	for _, s := range ss {
-		match, err := filepath.Glob(s)
-		if err != nil {
-			panic(err)
-		}
-		paths = append(paths, match...)
-	}
-
-	if len(paths) == 0 {
-		panic("no valid datset specified.")
-	}
-
-	idx := 0
-	for _, path := range paths {
-		f, err := os.Open(path)
-		if err != nil {
-			panic(err)
-		}
-
-		index, err := recordio.LoadIndex(f)
-		if err != nil {
-			panic(err)
-		}
-		f.Close()
-		count := index.NumChunks()
-		for i := 0; i < count; i++ {
-			chunk := master.Chunk{
-				Idx:   idx,
-				Path:  path,
-				Index: *index.ChunkIndex(i),
-			}
-			chunks = append(chunks, chunk)
-		}
-	}
 
-	s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
+	s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
 	err := rpc.Register(s)
 	if err != nil {
 		panic(err)
diff --git a/go/pserver/internal/connection/conn.go b/go/connection/conn.go
similarity index 82%
rename from go/pserver/internal/connection/conn.go
rename to go/connection/conn.go
index 1c04f117254054741b7d45fb16462b5ce84a2aea..bc9b5f0617e35f049c3e14f0b441aca2033f9645 100644
--- a/go/pserver/internal/connection/conn.go
+++ b/go/connection/conn.go
@@ -2,6 +2,7 @@ package connection
 
 import (
 	"errors"
+	"log"
 	"net/rpc"
 	"sync"
 )
@@ -21,6 +22,18 @@ func New() *Conn {
 	return c
 }
 
+// Close closes the connection.
+func (c *Conn) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.client == nil {
+		return nil
+	}
+
+	return c.client.Close()
+}
+
 // Connect connects the connection to a address.
 func (c *Conn) Connect(addr string) error {
 	c.mu.Lock()
@@ -50,12 +63,20 @@ func (c *Conn) Connect(addr string) error {
 			c.waitConn = nil
 		}
 	} else {
+		err := client.Close()
+		if err != nil {
+			log.Println(err)
+		}
+
 		return errors.New("client already set from a concurrent goroutine")
 	}
 
 	return nil
 }
 
+// TODO(helin): refactor Call to be able to perform a given retry
+// policy.
+
 // Call make a RPC call.
 //
 // Call will be blocked until the connection to remote RPC service
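
Note on the two additions above: Close is a no-op on a never-connected Conn, and a goroutine that loses the Connect race now closes its redundant rpc.Client instead of leaking it. A minimal caller-side sketch, assuming a reachable server; `Service.SomeMethod` is a hypothetical remote method, not part of this patch:

```go
package main

import (
	"log"

	"github.com/PaddlePaddle/Paddle/go/connection"
)

func main() {
	conn := connection.New()

	// Connect may be called from concurrent goroutines; the loser
	// of a race gets an error and its extra rpc.Client is closed
	// rather than leaked.
	if err := conn.Connect("localhost:8080"); err != nil {
		log.Println(err)
	}

	// Call blocks until a connection is established, so callers do
	// not need to order Call strictly after a successful Connect.
	var reply int
	if err := conn.Call("Service.SomeMethod", 0, &reply); err != nil {
		log.Println(err)
	}

	// Close on an unconnected Conn returns nil, so teardown is safe
	// regardless of connection state.
	if err := conn.Close(); err != nil {
		log.Println(err)
	}
}
```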
diff --git a/go/master/client.go b/go/master/client.go
new file mode 100644
index 0000000000000000000000000000000000000000..20c66340dc28bc514b6c51583dd94830c42a41bf
--- /dev/null
+++ b/go/master/client.go
@@ -0,0 +1,82 @@
+package master
+
+import (
+	"log"
+	"time"
+
+	"github.com/PaddlePaddle/Paddle/go/connection"
+)
+
+// Addresser provides the address of the master server.
+type Addresser interface {
+	Address() string
+}
+
+// Client is the client of the master server.
+type Client struct {
+	conn *connection.Conn
+}
+
+// NewClient creates a new Client.
+func NewClient(addr Addresser) *Client {
+	c := &Client{}
+	c.conn = connection.New()
+	go c.monitorMaster(addr)
+	return c
+}
+
+func (c *Client) monitorMaster(addr Addresser) {
+	lastMaster := ""
+	monitor := func() {
+		// get the latest address of the master server,
+		// connect to the new address once the address changes.
+		curMaster := addr.Address()
+		if curMaster != lastMaster {
+			if curMaster == "" {
+				err := c.conn.Close()
+				if err != nil {
+					log.Println(err)
+				}
+			} else {
+				err := c.conn.Connect(curMaster)
+				if err != nil {
+					log.Println(err)
+
+					// connect to addr failed, set
+					// to last known addr in order
+					// to retry next time.
+					curMaster = lastMaster
+				}
+			}
+		}
+
+		lastMaster = curMaster
+	}
+
+	monitor()
+	ticker := time.NewTicker(10 * time.Second)
+	for range ticker.C {
+		monitor()
+	}
+}
+
+// SetDataset sets the dataset for the master server to dispatch.
+//
+// SetDataset can be called multiple times from different nodes, but
+// only the first call will be honored.
+func (c *Client) SetDataset(globPaths []string) error {
+	return c.conn.Call("Service.SetDataset", globPaths, nil)
+}
+
+// GetTask gets a new task from the master server.
+func (c *Client) GetTask() (Task, error) {
+	var t Task
+	err := c.conn.Call("Service.GetTask", 0, &t)
+	return t, err
+}
+
+// TaskFinished tells the master server a task is finished.
+func (c *Client) TaskFinished(taskID int) error {
+	return c.conn.Call("Service.TaskFinished", taskID, nil)
+}
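
Client only knows the master through the Addresser interface: monitorMaster polls Address() every 10 seconds and reconnects (or disconnects) when the returned address changes, which is the hook a future etcd-backed addresser can plug into. A minimal trainer-side sketch, assuming a fixed address; `staticAddresser` and the glob path are illustrative, not part of this patch:

```go
package main

import (
	"log"

	"github.com/PaddlePaddle/Paddle/go/master"
)

// staticAddresser is a hypothetical Addresser that always returns the
// same address; client_test.go uses the same trick with a string type.
type staticAddresser string

func (a staticAddresser) Address() string { return string(a) }

func main() {
	c := master.NewClient(staticAddresser("localhost:8080"))

	// Every trainer may call SetDataset; the master honors only the
	// first call and treats the rest as successful no-ops.
	if err := c.SetDataset([]string{"/data/train-*.recordio"}); err != nil {
		log.Fatal(err)
	}

	for {
		t, err := c.GetTask()
		if err != nil {
			// Could be "no more available task" for this pass;
			// see the service-side TODO about distinguishing
			// errors across RPC.
			log.Println(err)
			break
		}

		// ... train on the chunks listed in t ...

		if err := c.TaskFinished(t.ID); err != nil {
			log.Println(err)
		}
	}
}
```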
diff --git a/go/master/client_test.go b/go/master/client_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..df708ad7912c07205ae0cee8d2ab1c06d65223cc
--- /dev/null
+++ b/go/master/client_test.go
@@ -0,0 +1,120 @@
+package master_test
+
+import (
+	"fmt"
+	"net"
+	"net/http"
+	"net/rpc"
+	"os"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+
+	"github.com/PaddlePaddle/Paddle/go/master"
+	"github.com/PaddlePaddle/recordio"
+)
+
+const (
+	totalTask    = 20
+	chunkPerTask = 10
+)
+
+var port int
+
+func init() {
+	log.SetLevel(log.ErrorLevel)
+
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		panic(err)
+	}
+
+	ss := strings.Split(l.Addr().String(), ":")
+	p, err := strconv.Atoi(ss[len(ss)-1])
+	if err != nil {
+		panic(err)
+	}
+	port = p
+
+	go func(l net.Listener) {
+		s := master.NewService(chunkPerTask, time.Second, 1)
+		server := rpc.NewServer()
+		err := server.Register(s)
+		if err != nil {
+			panic(err)
+		}
+
+		mux := http.NewServeMux()
+		mux.Handle(rpc.DefaultRPCPath, server)
+		err = http.Serve(l, mux)
+		if err != nil {
+			panic(err)
+		}
+	}(l)
+}
+
+type addresser string
+
+func (a addresser) Address() string {
+	return string(a)
+}
+
+func TestClientFull(t *testing.T) {
+	const p = "/tmp/master_client_test_0"
+	f, err := os.Create(p)
+	if err != nil {
+		panic(err)
+	}
+
+	for i := 0; i < totalTask*chunkPerTask; i++ {
+		w := recordio.NewWriter(f, -1, -1)
+		w.Write(nil)
+		// call Close to force RecordIO writing a chunk.
+		w.Close()
+	}
+	f.Close()
+
+	c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
+	c.SetDataset([]string{p})
+
+	checkOnePass := func(i int) {
+		var tasks []master.Task
+		for i := 0; i < totalTask; i++ {
+			task, err := c.GetTask()
+			if err != nil {
+				t.Fatal(i, err)
+			}
+			tasks = append(tasks, task)
+		}
+
+		_, err = c.GetTask()
+		if err == nil {
+			t.Fatal(i, "should get error.")
+		}
+
+		err = c.TaskFinished(tasks[0].ID)
+		if err != nil {
+			t.Fatal(err)
+		}
+		tasks = tasks[1:]
+		task, err := c.GetTask()
+		if err != nil {
+			t.Fatal(err)
+		}
+		tasks = append(tasks, task)
+
+		for _, task := range tasks {
+			err = c.TaskFinished(task.ID)
+			if err != nil {
+				t.Fatal(i, err)
+			}
+		}
+	}
+
+	for i := 0; i < 10; i++ {
+		checkOnePass(i)
+	}
+}
diff --git a/go/master/service.go b/go/master/service.go
index ab17a62f3854c1e32d731037fcc9857260d03781..1e2a34972bb4763688d7df97d768320cd6f9e9d4 100644
--- a/go/master/service.go
+++ b/go/master/service.go
@@ -2,29 +2,25 @@ package master
 
 import (
 	"errors"
-	"log"
+	"os"
+	"path/filepath"
 	"sync"
 	"time"
 
-	"github.com/PaddlePaddle/recordio"
-)
+	log "github.com/sirupsen/logrus"
 
-const (
-	targetTaskCount = 300
-)
-
-// errors
-var (
-	ErrNoMoreTask          = errors.New("no more task for current pass")
-	ErrPendingTaskNotFound = errors.New("pending task not found")
+	"github.com/PaddlePaddle/recordio"
 )
 
 // Service is the master server service.
 type Service struct {
-	timeoutDur time.Duration
-	timeoutMax int
+	chunksPerTask int
+	timeoutDur    time.Duration
+	timeoutMax    int
+	ready         chan struct{}
 
 	mu         sync.Mutex
+	initDone   bool
 	taskQueues taskQueues
 }
@@ -55,7 +51,6 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
 
 	if len(cur.Task.Chunks) > 0 {
 		cur.Task.ID = id
-		id++
 		result = append(result, cur)
 	}
 
@@ -63,21 +58,21 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
 }
 
 // NewService creates a new service.
-func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
+func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
 	s := &Service{}
+	s.chunksPerTask = chunksPerTask
 	s.timeoutDur = timeoutDur
 	s.timeoutMax = timeoutMax
 	s.taskQueues = taskQueues{}
 	s.taskQueues.Pending = make(map[int]taskEntry)
-	s.taskQueues.Todo = partition(chunks, chunksPerTask)
+	s.ready = make(chan struct{})
 	return s
 }
 
 // Chunk is a chunk of data consisted of several data instances.
 type Chunk struct {
-	Idx   int // index of the chunk within the file
 	Path  string
-	Index recordio.Index // block index
+	Index recordio.Index // chunk index
 }
 
 // Task is the basic unit of data instances assigned to trainers.
@@ -105,74 +100,205 @@ func (s *Service) snapshot() error {
 	return nil
 }
 
-// GetTask gets a new task from the service.
-func (s *Service) GetTask(dummy int, task *Task) error {
+func readChunks(globPaths []string) ([]Chunk, error) {
+	var chunks []Chunk
+	var paths []string
+
+	for _, s := range globPaths {
+		match, err := filepath.Glob(s)
+		if err != nil {
+			return nil, err
+		}
+		paths = append(paths, match...)
+	}
+
+	if len(paths) == 0 {
+		return nil, errors.New("no valid dataset specified")
+	}
+
+	for _, path := range paths {
+		f, err := os.Open(path)
+		if err != nil {
+			return nil, err
+		}
+
+		index, err := recordio.LoadIndex(f)
+		if err != nil {
+			return nil, err
+		}
+		err = f.Close()
+		if err != nil {
+			return nil, err
+		}
+
+		count := index.NumChunks()
+		for i := 0; i < count; i++ {
+			chunk := Chunk{
+				Path:  path,
+				Index: *index.ChunkIndex(i),
+			}
+			chunks = append(chunks, chunk)
+		}
+	}
+
+	return chunks, nil
+}
+
+// SetDataset sets the dataset for the master server to dispatch.
+//
+// SetDataset can be called multiple times, but only the first call
+// will be honored.
+func (s *Service) SetDataset(globPaths []string, dummy *int) error {
+	if len(globPaths) == 0 {
+		return errors.New("no dataset specified")
+	}
+
 	s.mu.Lock()
 	defer s.mu.Unlock()
+	if s.initDone {
+		// Already initialized. All trainers will call
+		// SetDataset, but we only handle the first one. Treat
+		// other calls as successful but do nothing.
+		return nil
+	}
 
-	if len(s.taskQueues.Todo) == 0 {
-		return ErrNoMoreTask
+	chunks, err := readChunks(globPaths)
+	if err != nil {
+		return err
 	}
 
-	t := s.taskQueues.Todo[0]
-	t.Epoch++
-	s.taskQueues.Todo = s.taskQueues.Todo[1:]
-	s.taskQueues.Pending[t.Task.ID] = t
-	err := s.snapshot()
+	s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
+
+	err = s.snapshot()
 	if err != nil {
+		log.Errorln(err)
 		return err
 	}
 
-	time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
-		return func() {
-			s.mu.Lock()
-			defer s.mu.Unlock()
+	close(s.ready)
+	s.initDone = true
+	return nil
+}
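
SetDataset above closes s.ready exactly once, guarded by initDone, and the RPC methods below block on the channel, so requests arriving before the dataset is set simply wait instead of failing. The pattern in isolation, as a minimal runnable sketch independent of the master code:

```go
package main

import (
	"fmt"
	"sync"
	"time"
)

// gate blocks callers until Init has run once: the same ready-channel
// plus initDone-flag pattern the master service uses.
type gate struct {
	mu       sync.Mutex
	initDone bool
	ready    chan struct{}
}

func newGate() *gate {
	return &gate{ready: make(chan struct{})}
}

func (g *gate) Init() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.initDone {
		// Closing a channel twice panics, so later calls no-op.
		return
	}
	close(g.ready)
	g.initDone = true
}

func (g *gate) Work() {
	<-g.ready // blocks until Init has been called once
	fmt.Println("working")
}

func main() {
	g := newGate()
	go func() {
		time.Sleep(100 * time.Millisecond)
		g.Init()
	}()
	g.Work() // waits ~100ms, then prints
}
```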
 
-			t, ok := s.taskQueues.Pending[taskID]
-			if !ok {
-				return
-			}
+func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
+	return func() {
+		s.mu.Lock()
+		defer s.mu.Unlock()
 
-			if t.Epoch != epoch {
-				// new epoch, task launched after the
-				// schedule of this timeout check.
-				return
+		t, ok := s.taskQueues.Pending[taskID]
+		if !ok {
+			return
+		}
+
+		if t.Epoch != epoch {
+			// new epoch, task launched after the
+			// schedule of this timeout check.
+			return
+		}
+
+		defer func() {
+			err := s.snapshot()
+			if err != nil {
+				log.Errorln(err)
 			}
+		}()
+
+		delete(s.taskQueues.Pending, t.Task.ID)
 
-			defer func() {
-				err := s.snapshot()
-				if err != nil {
-					log.Println(err)
-				}
-			}()
+		t.NumTimeout++
+		if t.NumTimeout > s.timeoutMax {
+			log.Warningf("Task %v failed %d times, discard.\n", t.Task, t.NumTimeout)
+			s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
+			return
+		}
 
-			delete(s.taskQueues.Pending, t.Task.ID)
+		log.Warningf("Task %v failed %d times, retry.\n", t.Task, t.NumTimeout)
+		s.taskQueues.Todo = append(s.taskQueues.Todo, t)
+	}
+}
 
-			t.NumTimeout++
-			if t.NumTimeout > s.timeoutMax {
-				s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
-				return
+// GetTask gets a new task from the service.
+func (s *Service) GetTask(dummy int, task *Task) error {
+	select {
+	case <-s.ready:
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if len(s.taskQueues.Todo) == 0 {
+		if len(s.taskQueues.Done) == 0 {
+			if len(s.taskQueues.Pending) == 0 {
+				err := errors.New("all tasks failed")
+				log.Warningln(err)
+				return err
 			}
-			s.taskQueues.Todo = append(s.taskQueues.Todo, t)
+
+			// TODO(helin): the client needs to retry in this
+			// error case. Gotcha: RPC client can't
+			// compare returned error with predefined
+			// errors like io.EOF, because the error
+			// instance deserialized from RPC is a
+			// different instance than the error defined
+			// in package. So we need to figure out a way
+			// for the client to check this error correctly.
+			err := errors.New("no more available task")
+			log.Warningln(err)
+			return err
 		}
-	}(t.Task.ID, t.Epoch))
+
+		s.taskQueues.Todo = s.taskQueues.Done
+		s.taskQueues.Done = nil
+		log.Infoln("No more todo tasks, but a trainer is requesting a task. Move all done tasks to todo.")
+	}
+
+	t := s.taskQueues.Todo[0]
+	t.Epoch++
+	s.taskQueues.Todo = s.taskQueues.Todo[1:]
+	s.taskQueues.Pending[t.Task.ID] = t
+	err := s.snapshot()
+	if err != nil {
+		return err
+	}
+
+	*task = t.Task
+	log.Infof("Task #%d dispatched\n", task.ID)
+
+	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
 	return nil
 }
 
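On the TODO in GetTask: error identity does not survive a net/rpc round trip. The server transmits only err.Error(), and the client materializes a fresh error value, so sentinel comparisons such as err == ErrNoMoreTask can never match. The pserver tests already cope with this by comparing message strings (pserver.AlreadyInitialized); a sketch of a client-side retry in the same style, where the exported constant and the helper are hypothetical, not part of this patch:

```go
package master

import "time"

// NoMoreAvailableTask mirrors the server-side error message; the name
// and constant are hypothetical, one possible answer to the TODO.
const NoMoreAvailableTask = "no more available task"

// getTaskWithRetry sketches the retry loop the TODO asks for: since
// only the message string survives the RPC round trip, the client
// matches on err.Error() instead of comparing error values.
func getTaskWithRetry(c *Client) (Task, error) {
	for {
		t, err := c.GetTask()
		if err == nil {
			return t, nil
		}

		// err == someSentinelError is always false after RPC, so
		// compare the transported string instead.
		if err.Error() == NoMoreAvailableTask {
			time.Sleep(time.Second)
			continue
		}

		return Task{}, err
	}
}
```
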
 // TaskFinished tell the service that a task is finished.
 func (s *Service) TaskFinished(taskID int, dummy *int) error {
+	select {
+	case <-s.ready:
+	}
+
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
+	log.Infof("Task %d finished\n", taskID)
+
 	t, ok := s.taskQueues.Pending[taskID]
 	if !ok {
-		return ErrPendingTaskNotFound
+		err := errors.New("pending task not found")
+		log.Warningln(err)
+		return err
 	}
 
 	// task finished, reset timeout
 	t.NumTimeout = 0
 	s.taskQueues.Done = append(s.taskQueues.Done, t)
 	delete(s.taskQueues.Pending, taskID)
-	return s.snapshot()
+
+	if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
+		log.Infoln("No more todo and pending tasks, start a new pass.")
+		s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
+		s.taskQueues.Done = nil
+	}
+
+	err := s.snapshot()
+	if err != nil {
+		log.Errorln(err)
+	}
+	return err
 }
diff --git a/go/pserver/client.go b/go/pserver/client.go
index f8bd0aa59f30ec7e2b2d318929af96135d3128ed..afe1eecd015b84684329c0e624f3753852d7a8ce 100644
--- a/go/pserver/client.go
+++ b/go/pserver/client.go
@@ -6,7 +6,7 @@ import (
 	"sort"
 	"time"
 
-	"github.com/PaddlePaddle/Paddle/go/pserver/internal/connection"
+	"github.com/PaddlePaddle/Paddle/go/connection"
 )
 
 // TODO(helin): add RPC call retry logic
@@ -47,7 +47,7 @@ func NewClient(l Lister, pserverNum int, sel Selector) *Client {
 // monitorPservers monitors pserver addresses, and updates connection
 // when the address changes.
 func (c *Client) monitorPservers(l Lister, pserverNum int) {
-	knownServers := make([]Server, pserverNum)
+	lastServers := make([]Server, pserverNum)
 	ticker := time.NewTicker(10 * time.Second)
 	monitor := func() {
 		curServers := make([]Server, pserverNum)
@@ -56,25 +56,37 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
 			curServers[l.Index] = l
 		}
 
-		for i := range knownServers {
-			if knownServers[i].Addr != curServers[i].Addr {
-				err := c.pservers[i].Connect(curServers[i].Addr)
+		for i := range lastServers {
+			if lastServers[i].Addr == curServers[i].Addr {
+				continue
+			}
+
+			if curServers[i].Addr == "" {
+				err := c.pservers[i].Close()
 				if err != nil {
 					log.Println(err)
-
-					// connect to addr failed, set
-					// to last known addr in order
-					// to retry next time.
-					curServers[i].Addr = knownServers[i].Addr
 				}
+
+				continue
 			}
+
+			err := c.pservers[i].Connect(curServers[i].Addr)
+			if err != nil {
+				log.Println(err)
+
+				// connect to addr failed, set
+				// to last known addr in order
+				// to retry next time.
+				curServers[i].Addr = lastServers[i].Addr
+			}
 		}
 
-		knownServers = curServers
+		lastServers = curServers
 	}
 
 	monitor()
-	for _ = range ticker.C {
+	for range ticker.C {
 		monitor()
 	}
 }
@@ -93,16 +105,14 @@ func (c *Client) BeginInitParams() bool {
 
 // InitParam initializes the parameter on parameter servers.
 func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
-	var dummy int
-	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
+	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
 }
 
 // FinishInitParams tells parameter servers client has sent all
 // parameters to parameter servers as initialization.
 func (c *Client) FinishInitParams() error {
 	for _, p := range c.pservers {
-		var dummy int
-		err := p.Call("Service.FinishInitParams", dummy, &dummy)
+		err := p.Call("Service.FinishInitParams", 0, nil)
 		if err != nil {
 			return err
 		}
@@ -116,8 +126,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
 	errCh := make(chan error, len(grads))
 	for _, g := range grads {
 		go func(g Gradient) {
-			var dummy int
-			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
+			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
 			errCh <- err
 		}(g)
 	}
@@ -196,8 +205,7 @@ func (c *Client) Save(path string) error {
 	errCh := make(chan error, len(c.pservers))
 	for _, p := range c.pservers {
-		var dummy int
-		err := p.Call("Service.Save", path, &dummy)
+		err := p.Call("Service.Save", path, nil)
 		errCh <- err
 	}
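
Dropping the throwaway `var dummy int` in favor of nil replies works because net/rpc's default gob codec discards the response body when the reply is nil (encoding/gob's Decode documents that a nil argument discards the value). The server-side method still needs a pointer reply parameter for net/rpc to accept it; it just never dereferences it. A self-contained sketch of the same pattern; Echo and Ping are illustrative names, not part of this patch:

```go
package main

import (
	"log"
	"net"
	"net/rpc"
)

// Echo is a toy service; Ping ignores its reply parameter, the way
// FinishInitParams does.
type Echo struct{}

// Ping never writes to reply, so clients may pass nil for it.
func (e *Echo) Ping(arg int, reply *int) error {
	return nil
}

func main() {
	srv, cli := net.Pipe()

	s := rpc.NewServer()
	if err := s.Register(&Echo{}); err != nil {
		log.Fatal(err)
	}
	go s.ServeConn(srv)

	c := rpc.NewClient(cli)
	// nil reply: the gob codec discards the (empty) response body,
	// so no dummy variable is needed on the client side.
	if err := c.Call("Echo.Ping", 0, nil); err != nil {
		log.Fatal(err)
	}
	log.Println("ok")
}
```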
diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go
index 796492ffb47f109b1d47101712195903b8dc8457..b746d13e1ca71e697c464f84d915af029d37120c 100644
--- a/go/pserver/service_test.go
+++ b/go/pserver/service_test.go
@@ -15,8 +15,7 @@ func TestFull(t *testing.T) {
 	p.Name = "param_a"
 	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
 	p.ElementType = pserver.Int32
-	var dummy int
-	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
+	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
 	if err != nil {
 		t.FailNow()
 	}
@@ -25,12 +24,12 @@ func TestFull(t *testing.T) {
 	p1.Name = "param_b"
 	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
 	p1.ElementType = pserver.Float32
-	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy)
+	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
 	if err != nil {
 		t.FailNow()
 	}
 
-	err = s.FinishInitParams(0, &dummy)
+	err = s.FinishInitParams(0, nil)
 	if err != nil {
 		t.FailNow()
 	}
@@ -46,11 +45,11 @@ func TestFull(t *testing.T) {
 	}
 
 	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
-	err = s.SendGrad(g1, &dummy)
+	err = s.SendGrad(g1, nil)
 	if err != nil {
 		t.FailNow()
 	}
-	err = s.SendGrad(g2, &dummy)
+	err = s.SendGrad(g2, nil)
 
 	if err != nil {
 		t.FailNow()
@@ -74,13 +73,12 @@
 
 func TestMultipleInit(t *testing.T) {
 	s := pserver.NewService()
-	var dummy int
-	err := s.FinishInitParams(0, &dummy)
+	err := s.FinishInitParams(0, nil)
 	if err != nil {
 		t.FailNow()
 	}
 
-	err = s.FinishInitParams(0, &dummy)
+	err = s.FinishInitParams(0, nil)
 	if err.Error() != pserver.AlreadyInitialized {
 		t.FailNow()
 	}
@@ -88,8 +86,7 @@ func TestMultipleInit(t *testing.T) {
 
 func TestUninitialized(t *testing.T) {
 	s := pserver.NewService()
-	var dummy int
-	err := s.SendGrad(pserver.Gradient{}, &dummy)
+	err := s.SendGrad(pserver.Gradient{}, nil)
 	if err.Error() != pserver.Uninitialized {
 		t.FailNow()
 	}
@@ -98,13 +95,14 @@ func TestUninitialized(t *testing.T) {
 func TestBlockUntilInitialized(t *testing.T) {
 	s := pserver.NewService()
 	ch := make(chan struct{}, 2)
+	errCh := make(chan error, 2)
 	var wg sync.WaitGroup
 	wg.Add(1)
 	go func() {
 		var param pserver.Parameter
 		err := s.GetParam("param_a", &param)
 		if err != nil {
-			t.FailNow()
+			errCh <- err
 		}
 		wg.Done()
 		ch <- struct{}{}
@@ -112,10 +110,9 @@ func TestBlockUntilInitialized(t *testing.T) {
 
 	wg.Add(1)
 	go func() {
-		var dummy int
-		err := s.Save("", &dummy)
+		err := s.Save("", nil)
 		if err != nil {
-			t.FailNow()
+			errCh <- err
 		}
 		wg.Done()
 		ch <- struct{}{}
@@ -127,6 +124,8 @@
 	case <-ch:
 		// some function returned before initialization is completed.
 		t.FailNow()
+	case <-errCh:
+		t.FailNow()
 	default:
 	}
 
@@ -134,13 +133,12 @@
 	p.Name = "param_a"
 	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
 	p.ElementType = pserver.Int32
-	var dummy int
-	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
+	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
 	if err != nil {
 		t.FailNow()
 	}
 
-	err = s.FinishInitParams(0, &dummy)
+	err = s.FinishInitParams(0, nil)
 	if err != nil {
 		t.FailNow()
 	}
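
The errCh change in TestBlockUntilInitialized fixes a real bug rather than style: testing.T.FailNow must be called from the goroutine running the test function, so the old t.FailNow() calls inside the spawned goroutines were invalid and would not stop the test. Funneling failures back over a buffered channel and failing from the test goroutine is the standard pattern; a distilled sketch, where doWork is a hypothetical stand-in for GetParam or Save:

```go
package pattern_test

import (
	"testing"
	"time"
)

// doWork stands in for a call like s.GetParam or s.Save that may
// fail on another goroutine.
func doWork() error { return nil }

func TestPattern(t *testing.T) {
	errCh := make(chan error, 1) // buffered: the sender never blocks

	go func() {
		if err := doWork(); err != nil {
			// Not t.FailNow(): FailNow must run on the
			// goroutine executing the test function.
			errCh <- err
		}
	}()

	select {
	case err := <-errCh:
		t.Fatal(err) // fail from the test goroutine
	case <-time.After(time.Second):
		// no error reported within the deadline
	}
}
```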