From 72a73ab6d2139fae73dc922505acad6d8aa41ec4 Mon Sep 17 00:00:00 2001
From: Helin Wang
Date: Thu, 8 Jun 2017 22:18:36 +0000
Subject: [PATCH] implement master server client, RPC part.

---
 go/cmd/master/master.go  |  2 --
 go/connection/conn.go    | 21 +++++++++++
 go/master/client.go      | 77 +++++++++++++++++++++++++++++++++++++++
 go/master/client_test.go | 78 ++++++++++++++++++++++++++++++++++++++++
 go/master/service.go     | 11 ++++--
 go/pserver/client.go     | 19 +++++++---
 6 files changed, 199 insertions(+), 9 deletions(-)
 create mode 100644 go/master/client.go
 create mode 100644 go/master/client_test.go

diff --git a/go/cmd/master/master.go b/go/cmd/master/master.go
index d1f3d7d76c4..65548b7b684 100644
--- a/go/cmd/master/master.go
+++ b/go/cmd/master/master.go
@@ -50,7 +50,6 @@ func main() {
 		panic("no valid datset specified.")
 	}
 
-	idx := 0
 	for _, path := range paths {
 		f, err := os.Open(path)
 		if err != nil {
@@ -66,7 +65,6 @@ func main() {
 		count := index.NumChunks()
 		for i := 0; i < count; i++ {
 			chunk := master.Chunk{
-				Idx:   idx,
 				Path:  path,
 				Index: *index.ChunkIndex(i),
 			}
diff --git a/go/connection/conn.go b/go/connection/conn.go
index 1c04f117254..0bab2def1d9 100644
--- a/go/connection/conn.go
+++ b/go/connection/conn.go
@@ -21,6 +21,24 @@ func New() *Conn {
 	return c
 }
 
+// Close closes the connection.
+//
+// The closed client is dropped, so calling Close again is a no-op.
+func (c *Conn) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.client == nil {
+		return nil
+	}
+
+	// Forget the client even if closing it fails: a client that has
+	// been closed must not be handed out again.
+	err := c.client.Close()
+	c.client = nil
+	return err
+}
+
 // Connect connects the connection to a address.
 func (c *Conn) Connect(addr string) error {
 	c.mu.Lock()
@@ -56,6 +74,9 @@ func (c *Conn) Connect(addr string) error {
 	return nil
 }
 
+// TODO(helin): refactor Call to be able to perform given retry
+// policy.
+
 // Call make a RPC call.
 //
 // Call will be blocked until the connection to remote RPC service
diff --git a/go/master/client.go b/go/master/client.go
new file mode 100644
index 00000000000..23ef18f9e27
--- /dev/null
+++ b/go/master/client.go
@@ -0,0 +1,77 @@
+package master
+
+import (
+	"log"
+	"time"
+
+	"github.com/PaddlePaddle/Paddle/go/connection"
+)
+
+// Addresser provides the address of the master server.
+type Addresser interface {
+	Address() string
+}
+
+// Client is the client of the master server.
+type Client struct {
+	conn *connection.Conn
+}
+
+// NewClient creates a new Client and starts a background goroutine
+// that keeps its connection pointed at the address reported by addr.
+func NewClient(addr Addresser) *Client {
+	c := &Client{}
+	c.conn = connection.New()
+	go c.monitorMaster(addr)
+	return c
+}
+
+// monitorMaster polls addr immediately and then every 10 seconds,
+// reconnecting when the address changes and closing the connection
+// when the address becomes empty. It never returns.
+func (c *Client) monitorMaster(addr Addresser) {
+	lastMaster := ""
+	monitor := func() {
+		curMaster := addr.Address()
+		if curMaster != lastMaster {
+			if curMaster == "" {
+				err := c.conn.Close()
+				if err != nil {
+					log.Println(err)
+				}
+			} else {
+				err := c.conn.Connect(curMaster)
+				if err != nil {
+					log.Println(err)
+
+					// connect to addr failed, set
+					// to last known addr in order
+					// to retry next time.
+					curMaster = lastMaster
+				}
+			}
+		}
+
+		lastMaster = curMaster
+	}
+
+	monitor()
+	ticker := time.NewTicker(10 * time.Second)
+	for range ticker.C {
+		monitor()
+	}
+}
+
+// GetTask gets a new task from the master server.
+func (c *Client) GetTask() (Task, error) {
+	var dummy int
+	var t Task
+	err := c.conn.Call("Service.GetTask", dummy, &t)
+	return t, err
+}
+
+// TaskFinished tells the master server a task is finished.
+func (c *Client) TaskFinished(taskID int) error {
+	var dummy int
+	return c.conn.Call("Service.TaskFinished", taskID, &dummy)
+}
diff --git a/go/master/client_test.go b/go/master/client_test.go
new file mode 100644
index 00000000000..4603bdc4d6b
--- /dev/null
+++ b/go/master/client_test.go
@@ -0,0 +1,78 @@
+package master_test
+
+import (
+	"fmt"
+	"net"
+	"net/http"
+	"net/rpc"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/PaddlePaddle/Paddle/go/master"
+)
+
+const (
+	totalTask    = 20
+	chunkPerTask = 10
+)
+
+var port int
+
+func init() {
+	l, err := net.Listen("tcp", ":0")
+	if err != nil {
+		panic(err)
+	}
+
+	ss := strings.Split(l.Addr().String(), ":")
+	p, err := strconv.Atoi(ss[len(ss)-1])
+	if err != nil {
+		panic(err)
+	}
+	port = p
+
+	go func(l net.Listener) {
+		chunks := make([]master.Chunk, totalTask)
+		s := master.NewService(chunks, chunkPerTask, time.Second, 1)
+		server := rpc.NewServer()
+		err := server.Register(s)
+		if err != nil {
+			panic(err)
+		}
+
+		mux := http.NewServeMux()
+		mux.Handle(rpc.DefaultRPCPath, server)
+		err = http.Serve(l, mux)
+		if err != nil {
+			panic(err)
+		}
+	}(l)
+}
+
+type addresser string
+
+func (a addresser) Address() string {
+	return string(a)
+}
+
+func TestClientFull(t *testing.T) {
+	c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
+
+	for i := 0; i < 5*totalTask/chunkPerTask; i++ {
+		task, err := c.GetTask()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if len(task.Chunks) != chunkPerTask {
+			t.Fatal("wrong number of chunk per task", len(task.Chunks))
+		}
+
+		err = c.TaskFinished(task.ID)
+		if err != nil {
+			t.Fatal(err)
+		}
+	}
+}
diff --git a/go/master/service.go b/go/master/service.go
index ab17a62f385..8d6bbecc497 100644
--- a/go/master/service.go
+++ b/go/master/service.go
@@ -75,9 +75,8 @@ func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, tim
 
 // Chunk is a chunk of data consisted of several data instances.
 type Chunk struct {
-	Idx   int // index of the chunk within the file
 	Path  string
-	Index recordio.Index // block index
+	Index recordio.Index // chunk index
 }
 
 // Task is the basic unit of data instances assigned to trainers.
@@ -123,6 +122,8 @@ func (s *Service) GetTask(dummy int, task *Task) error {
 		return err
 	}
 
+	*task = t.Task
+
 	time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
 		return func() {
 			s.mu.Lock()
@@ -174,5 +175,11 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
 	t.NumTimeout = 0
 	s.taskQueues.Done = append(s.taskQueues.Done, t)
 	delete(s.taskQueues.Pending, taskID)
+
+	if len(s.taskQueues.Todo) == 0 {
+		s.taskQueues.Todo = s.taskQueues.Done
+		s.taskQueues.Done = nil
+	}
+
 	return s.snapshot()
 }
diff --git a/go/pserver/client.go b/go/pserver/client.go
index 4f35141a9f7..7930f012c36 100644
--- a/go/pserver/client.go
+++ b/go/pserver/client.go
@@ -47,7 +47,7 @@ func NewClient(l Lister, pserverNum int, sel Selector) *Client {
 // monitorPservers monitors pserver addresses, and updates connection
 // when the address changes.
 func (c *Client) monitorPservers(l Lister, pserverNum int) {
-	knownServers := make([]Server, pserverNum)
+	lastServers := make([]Server, pserverNum)
 	ticker := time.NewTicker(10 * time.Second)
 	monitor := func() {
 		curServers := make([]Server, pserverNum)
@@ -56,8 +56,17 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
 			curServers[l.Index] = l
 		}
 
-		for i := range knownServers {
-			if knownServers[i].Addr != curServers[i].Addr {
+		for i := range lastServers {
+			if lastServers[i].Addr != curServers[i].Addr {
+				if curServers[i].Addr == "" {
+					err := c.pservers[i].Close()
+					if err != nil {
+						log.Println(err)
+					}
+
+					continue
+				}
+
 				err := c.pservers[i].Connect(curServers[i].Addr)
 				if err != nil {
 					log.Println(err)
@@ -65,12 +74,12 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
 					// connect to addr failed, set
 					// to last known addr in order
 					// to retry next time.
-					curServers[i].Addr = knownServers[i].Addr
+					curServers[i].Addr = lastServers[i].Addr
 				}
 			}
 		}
 
-		knownServers = curServers
+		lastServers = curServers
 	}
 
 	monitor()
-- 
GitLab