From fa5c3f1f736a658fa2ada9edf716b922d45ca563 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 14 Jun 2017 21:51:45 +0000 Subject: [PATCH] implement master client, Go part --- go/master/c/client.go | 81 ++++++++++++++++++++ go/master/client.go | 53 ++++++++++++- go/master/client_internal_test.go | 120 ++++++++++++++++++++++++++++++ go/master/client_test.go | 85 ++++++--------------- go/master/service.go | 26 +++++-- 5 files changed, 290 insertions(+), 75 deletions(-) create mode 100644 go/master/c/client.go create mode 100644 go/master/client_internal_test.go diff --git a/go/master/c/client.go b/go/master/c/client.go new file mode 100644 index 0000000000..220184c3af --- /dev/null +++ b/go/master/c/client.go @@ -0,0 +1,81 @@ +package main + +/* + +typedef int paddle_master_client; +*/ +import "C" + +import ( + "log" + "sync" + "unsafe" + + "github.com/PaddlePaddle/Paddle/go/master" +) + +var mu sync.Mutex +var handleMap = make(map[C.paddle_master_client]*master.Client) +var curHandle C.paddle_master_client + +func add(c *master.Client) C.paddle_master_client { + mu.Lock() + defer mu.Unlock() + client := curHandle + curHandle++ + handleMap[client] = c + return client +} + +func get(client C.paddle_master_client) *master.Client { + mu.Lock() + defer mu.Unlock() + return handleMap[client] +} + +func remove(client C.paddle_master_client) *master.Client { + mu.Lock() + defer mu.Unlock() + h := handleMap[client] + delete(handleMap, client) + return h +} + +type addresser string + +func (a addresser) Address() string { + return string(a) +} + +//paddle_new_master_client +func paddle_new_master_client(addr *C.char, buf_size C.int) C.paddle_master_client { + a := C.GoString(addr) + c := master.NewClient(addresser(a), int(buf_size)) + return add(c) +} + +//export paddle_new_etcd_master_client +func paddle_new_etcd_master_client(etcd_addr *C.char) C.paddle_master_client { + // TODO(helin): fault tolerant master client using etcd. + panic("not implemented.") +} + +//export paddle_set_dataset +func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { + c := get(client) + var paths []string + for i := 0; i < int(size); i++ { + ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(size))) + str := C.GoString(*ptr) + paths = append(paths, str) + } + err := c.SetDataset(paths) + if err != nil { + log.Println(err) + return -1 + } + + return 0 +} + +func main() {} diff --git a/go/master/client.go b/go/master/client.go index 20c66340dc..1c8a8d73d0 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -2,9 +2,11 @@ package master import ( "log" + "os" "time" "github.com/PaddlePaddle/Paddle/go/connection" + "github.com/PaddlePaddle/recordio" ) // Addresser provide the address of the master server. @@ -15,16 +17,51 @@ type Addresser interface { // Client is the client of the master server. type Client struct { conn *connection.Conn + ch chan []byte } // NewClient creates a new Client. -func NewClient(addr Addresser) *Client { +// +// bufSize is the record buffer size. NextRecord will read from the +// buffer. +func NewClient(addr Addresser, bufSize int) *Client { c := &Client{} c.conn = connection.New() + c.ch = make(chan []byte, bufSize) go c.monitorMaster(addr) + go c.getRecords() return c } +func (c *Client) getRecords() { + for { + t, err := c.getTask() + if err != nil { + log.Println(err) + continue + } + + for _, chunk := range t.Chunks { + f, err := os.Open(chunk.Path) + if err != nil { + log.Println(err) + continue + } + + s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1) + for s.Scan() { + c.ch <- s.Record() + } + + err = f.Close() + if err != nil { + log.Println(err) + } + } + c.taskFinished(t.ID) + } +} + func (c *Client) monitorMaster(addr Addresser) { lastMaster := "" monitor := func() { @@ -69,14 +106,22 @@ func (c *Client) SetDataset(globPaths []string) error { return c.conn.Call("Service.SetDataset", globPaths, nil) } -// GetTask gets a new task from the master server. -func (c *Client) GetTask() (Task, error) { +// getTask gets a new task from the master server. +func (c *Client) getTask() (Task, error) { var t Task err := c.conn.Call("Service.GetTask", 0, &t) return t, err } // TaskFinished tells the master server a task is finished. -func (c *Client) TaskFinished(taskID int) error { +func (c *Client) taskFinished(taskID int) error { return c.conn.Call("Service.TaskFinished", taskID, nil) } + +// NextRecord returns next record in the dataset. +// +// NextRecord will block until next record is available. It is +// thread-safe. +func (c *Client) NextRecord() []byte { + return <-c.ch +} diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go new file mode 100644 index 0000000000..362668202a --- /dev/null +++ b/go/master/client_internal_test.go @@ -0,0 +1,120 @@ +package master + +import ( + "fmt" + "net" + "net/http" + "net/rpc" + "os" + "strconv" + "strings" + "testing" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/PaddlePaddle/Paddle/go/connection" + "github.com/PaddlePaddle/recordio" +) + +const ( + totalTask = 20 + chunkPerTask = 10 +) + +func init() { + log.SetLevel(log.ErrorLevel) +} + +type TestAddresser string + +func (a TestAddresser) Address() string { + return string(a) +} + +func TestGetFinishTask(t *testing.T) { + const path = "/tmp/master_client_test_0" + + l, err := net.Listen("tcp", ":0") + if err != nil { + panic(err) + } + + ss := strings.Split(l.Addr().String(), ":") + p, err := strconv.Atoi(ss[len(ss)-1]) + if err != nil { + panic(err) + } + + go func(l net.Listener) { + s := NewService(chunkPerTask, time.Second, 1) + server := rpc.NewServer() + err := server.Register(s) + if err != nil { + panic(err) + } + + mux := http.NewServeMux() + mux.Handle(rpc.DefaultRPCPath, server) + err = http.Serve(l, mux) + if err != nil { + panic(err) + } + }(l) + + f, err := os.Create(path) + if err != nil { + panic(err) + } + + for i := 0; i < totalTask*chunkPerTask; i++ { + w := recordio.NewWriter(f, -1, -1) + w.Write(nil) + // call Close to force RecordIO writing a chunk. + w.Close() + } + f.Close() + + c := &Client{} + c.conn = connection.New() + go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p))) + c.SetDataset([]string{path}) + + checkOnePass := func(i int) { + var tasks []Task + for idx := 0; idx < totalTask; idx++ { + task, err := c.getTask() + if err != nil { + t.Fatal(err, " pass:", i) + } + tasks = append(tasks, task) + } + + _, err = c.getTask() + if err == nil { + t.Fatal("Should get error. Pass:", i) + } + + err = c.taskFinished(tasks[0].ID) + if err != nil { + t.Fatal(err, "pass:", i) + } + tasks = tasks[1:] + task, err := c.getTask() + if err != nil { + t.Fatal(err) + } + tasks = append(tasks, task) + + for _, task := range tasks { + err = c.taskFinished(task.ID) + if err != nil { + t.Fatal(err, " pass:", i) + } + } + } + + for i := 0; i < 10; i++ { + checkOnePass(i) + } +} diff --git a/go/master/client_test.go b/go/master/client_test.go index df708ad791..2b3f873ecf 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -11,21 +11,15 @@ import ( "testing" "time" - log "github.com/sirupsen/logrus" - "github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/recordio" ) -const ( - totalTask = 20 - chunkPerTask = 10 -) - -var port int - -func init() { - log.SetLevel(log.ErrorLevel) +func TestNextRecord(t *testing.T) { + const ( + path = "/tmp/master_client_TestFull" + total = 50 + ) l, err := net.Listen("tcp", ":0") if err != nil { @@ -37,10 +31,9 @@ func init() { if err != nil { panic(err) } - port = p go func(l net.Listener) { - s := master.NewService(chunkPerTask, time.Second, 1) + s := master.NewService(10, time.Second, 1) server := rpc.NewServer() err := server.Register(s) if err != nil { @@ -54,67 +47,33 @@ func init() { panic(err) } }(l) -} -type addresser string - -func (a addresser) Address() string { - return string(a) -} - -func TestClientFull(t *testing.T) { - const p = "/tmp/master_client_test_0" - f, err := os.Create(p) + f, err := os.Create(path) if err != nil { panic(err) } - for i := 0; i < totalTask*chunkPerTask; i++ { - w := recordio.NewWriter(f, -1, -1) - w.Write(nil) - // call Close to force RecordIO writing a chunk. - w.Close() + w := recordio.NewWriter(f, -1, -1) + for i := 0; i < total; i++ { + w.Write([]byte{byte(i)}) } + w.Close() f.Close() - c := master.NewClient(addresser(fmt.Sprintf(":%d", port))) - c.SetDataset([]string{p}) + c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10) + c.SetDataset([]string{path}) - checkOnePass := func(i int) { - var tasks []master.Task - for i := 0; i < totalTask; i++ { - task, err := c.GetTask() - if err != nil { - t.Fatal(i, err) + for pass := 0; pass < 50; pass++ { + received := make(map[byte]bool) + for i := 0; i < total; i++ { + r := c.NextRecord() + if len(r) != 1 { + t.Fatal("Length should be 1.", r) } - tasks = append(tasks, task) - } - - _, err = c.GetTask() - if err == nil { - t.Fatal(i, "should get error.") - } - - err = c.TaskFinished(tasks[0].ID) - if err != nil { - t.Fatal(err) - } - tasks = tasks[1:] - task, err := c.GetTask() - if err != nil { - t.Fatal(err) - } - tasks = append(tasks, task) - - for _, task := range tasks { - err = c.TaskFinished(task.ID) - if err != nil { - t.Fatal(i, err) + if received[r[0]] { + t.Fatal("Received duplicate.", received, r) } + received[r[0]] = true } } - - for i := 0; i < 10; i++ { - checkOnePass(i) - } } diff --git a/go/master/service.go b/go/master/service.go index 1e2a34972b..2e165138fb 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -217,6 +217,16 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { } } +// must be called with lock held. +func (s *Service) logFields() log.Fields { + return log.Fields{ + "todoLen": len(s.taskQueues.Todo), + "pendingLen": len(s.taskQueues.Pending), + "doneLen": len(s.taskQueues.Done), + "failedLen": len(s.taskQueues.Failed), + } +} + // GetTask gets a new task from the service. func (s *Service) GetTask(dummy int, task *Task) error { select { @@ -230,7 +240,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { if len(s.taskQueues.Done) == 0 { if len(s.taskQueues.Pending) == 0 { err := errors.New("all task failed") - log.Warningln(err) + log.WithFields(s.logFields()).Warningln("All tasks failed.") return err } @@ -243,12 +253,12 @@ func (s *Service) GetTask(dummy int, task *Task) error { // in package. So we need to figure out a way // for client to check this error correctly. err := errors.New("no more available task") - log.Warningln(err) + log.WithFields(s.logFields()).Warningln("No more available task.") return err } s.taskQueues.Todo = s.taskQueues.Done s.taskQueues.Done = nil - log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.") + log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.") } t := s.taskQueues.Todo[0] @@ -261,7 +271,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { } *task = t.Task - log.Infof("Task #%d dispatched\n", task.ID) + log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID) time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch)) return nil @@ -276,12 +286,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { s.mu.Lock() defer s.mu.Unlock() - log.Infof("Task %d finished\n", taskID) - t, ok := s.taskQueues.Pending[taskID] if !ok { err := errors.New("pending task not found") - log.Warningln(err) + log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID) return err } @@ -290,8 +298,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { s.taskQueues.Done = append(s.taskQueues.Done, t) delete(s.taskQueues.Pending, taskID) + log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) + if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { - log.Infoln("No more todo and pending task, start a new pass.") + log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.") s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) s.taskQueues.Done = nil } -- GitLab