提交 72a73ab6 编写于 作者: H Helin Wang

implement master server client, RPC part.

上级 f05649af
...@@ -50,7 +50,6 @@ func main() { ...@@ -50,7 +50,6 @@ func main() {
panic("no valid datset specified.") panic("no valid datset specified.")
} }
idx := 0
for _, path := range paths { for _, path := range paths {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
...@@ -66,7 +65,6 @@ func main() { ...@@ -66,7 +65,6 @@ func main() {
count := index.NumChunks() count := index.NumChunks()
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
chunk := master.Chunk{ chunk := master.Chunk{
Idx: idx,
Path: path, Path: path,
Index: *index.ChunkIndex(i), Index: *index.ChunkIndex(i),
} }
......
...@@ -21,6 +21,18 @@ func New() *Conn { ...@@ -21,6 +21,18 @@ func New() *Conn {
return c return c
} }
// Close closes the connection.
//
// The client is closed while holding c.mu. When no client has been
// set yet there is nothing to close and nil is returned.
func (c *Conn) Close() error {
	c.mu.Lock()
	defer c.mu.Unlock()

	if client := c.client; client != nil {
		return client.Close()
	}
	return nil
}
// Connect connects the connection to an address. // Connect connects the connection to an address.
func (c *Conn) Connect(addr string) error { func (c *Conn) Connect(addr string) error {
c.mu.Lock() c.mu.Lock()
...@@ -56,6 +68,9 @@ func (c *Conn) Connect(addr string) error { ...@@ -56,6 +68,9 @@ func (c *Conn) Connect(addr string) error {
return nil return nil
} }
// TODO(helin): refactor Call to be able to perform given retry
// policy.
// Call makes an RPC call. // Call makes an RPC call.
// //
// Call will be blocked until the connection to remote RPC service // Call will be blocked until the connection to remote RPC service
......
package master
import (
"log"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
)
// Addresser provides the address of the master server.
type Addresser interface {
// Address returns the current master server address. The client
// treats an empty string as "no address" and closes its connection.
Address() string
}
// Client is the client of the master server.
type Client struct {
// conn is the connection all RPC calls go through. It is created
// in NewClient and connected or closed by monitorMaster.
conn *connection.Conn
}
// NewClient creates a new Client.
//
// The returned client owns a fresh connection, and a goroutine running
// monitorMaster is started to keep that connection following the
// address reported by addr.
func NewClient(addr Addresser) *Client {
	client := &Client{conn: connection.New()}
	go client.monitorMaster(addr)
	return client
}
// monitorMaster polls addr for the master server address, once
// immediately and then every 10 seconds, and updates the connection
// whenever the reported address differs from the last one acted upon.
//
// It loops for as long as the ticker delivers ticks, so it is meant to
// be run in its own goroutine.
func (c *Client) monitorMaster(addr Addresser) {
	lastMaster := ""
	monitor := func() {
		curMaster := addr.Address()
		if curMaster == lastMaster {
			return
		}

		if curMaster == "" {
			// No address reported: close the current connection.
			if err := c.conn.Close(); err != nil {
				log.Println(err)
			}
			lastMaster = curMaster
			return
		}

		if err := c.conn.Connect(curMaster); err != nil {
			log.Println(err)
			// Connecting to the new address failed; leave
			// lastMaster untouched so the next tick retries.
			return
		}
		lastMaster = curMaster
	}

	monitor()

	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for range ticker.C {
		monitor()
	}
}
// GetTask gets a new task from the master server.
//
// It issues the "Service.GetTask" RPC over the client's connection and
// returns the decoded task together with the error from the call.
func (c *Client) GetTask() (Task, error) {
	var (
		unused int // placeholder argument for the RPC
		task   Task
	)
	err := c.conn.Call("Service.GetTask", unused, &task)
	return task, err
}
// TaskFinished tells the master server a task is finished.
//
// taskID identifies the finished task; the error from the
// "Service.TaskFinished" RPC is returned unchanged.
func (c *Client) TaskFinished(taskID int) error {
	var reply int // placeholder reply for the RPC
	err := c.conn.Call("Service.TaskFinished", taskID, &reply)
	return err
}
package master_test
import (
"fmt"
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
)
// Sizing of the in-process master service started by init.
const (
// totalTask is the length of the chunk slice handed to
// master.NewService; note that it counts chunks, not tasks.
totalTask = 20
// chunkPerTask is passed to master.NewService as the number of
// chunks per task, and is what TestClientFull expects per task.
chunkPerTask = 10
)
// port is the TCP port the test master server listens on; set by init.
var port int
// init starts an in-process master RPC server for the tests.
//
// It listens on a system-chosen TCP port (":0"), stores that port in
// the package-level variable port, and serves the master service over
// HTTP at rpc.DefaultRPCPath from a background goroutine. Any setup or
// serving error panics.
func init() {
	listener, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}

	// The port is whatever follows the last ':' of the listener address.
	addr := listener.Addr().String()
	portText := addr[strings.LastIndex(addr, ":")+1:]
	p, err := strconv.Atoi(portText)
	if err != nil {
		panic(err)
	}
	port = p

	go func() {
		svc := master.NewService(make([]master.Chunk, totalTask), chunkPerTask, time.Second, 1)

		rpcServer := rpc.NewServer()
		if err := rpcServer.Register(svc); err != nil {
			panic(err)
		}

		mux := http.NewServeMux()
		mux.Handle(rpc.DefaultRPCPath, rpcServer)
		if err := http.Serve(listener, mux); err != nil {
			panic(err)
		}
	}()
}
// addresser is a fixed address string that exposes itself through an
// Address method, so it can be handed to master.NewClient.
type addresser string

// Address returns the address the value was created from.
func (addr addresser) Address() string {
	return string(addr)
}
// TestClientFull repeatedly gets a task from the test master server
// and reports it finished, checking that every task carries exactly
// chunkPerTask chunks.
func TestClientFull(t *testing.T) {
	c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))

	// NOTE(review): 5*totalTask/chunkPerTask presumably means five
	// passes over the totalTask/chunkPerTask tasks — confirm.
	for i := 0; i < 5*totalTask/chunkPerTask; i++ {
		task, err := c.GetTask()
		if err != nil {
			// Fail this test only; a panic would abort the whole
			// test binary.
			t.Fatal(err)
		}

		if len(task.Chunks) != chunkPerTask {
			t.Fatal("wrong number of chunk per task", len(task.Chunks))
		}

		if err := c.TaskFinished(task.ID); err != nil {
			t.Fatal(err)
		}
	}
}
...@@ -75,9 +75,8 @@ func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, tim ...@@ -75,9 +75,8 @@ func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, tim
// Chunk is a chunk of data consisting of several data instances. // Chunk is a chunk of data consisting of several data instances.
type Chunk struct { type Chunk struct {
Idx int // index of the chunk within the file
Path string Path string
Index recordio.Index // block index Index recordio.Index // chunk index
} }
// Task is the basic unit of data instances assigned to trainers. // Task is the basic unit of data instances assigned to trainers.
...@@ -123,6 +122,8 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -123,6 +122,8 @@ func (s *Service) GetTask(dummy int, task *Task) error {
return err return err
} }
*task = t.Task
time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() { time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
return func() { return func() {
s.mu.Lock() s.mu.Lock()
...@@ -174,5 +175,11 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error { ...@@ -174,5 +175,11 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
t.NumTimeout = 0 t.NumTimeout = 0
s.taskQueues.Done = append(s.taskQueues.Done, t) s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID) delete(s.taskQueues.Pending, taskID)
if len(s.taskQueues.Todo) == 0 {
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
}
return s.snapshot() return s.snapshot()
} }
...@@ -47,7 +47,7 @@ func NewClient(l Lister, pserverNum int, sel Selector) *Client { ...@@ -47,7 +47,7 @@ func NewClient(l Lister, pserverNum int, sel Selector) *Client {
// monitorPservers monitors pserver addresses, and updates connection // monitorPservers monitors pserver addresses, and updates connection
// when the address changes. // when the address changes.
func (c *Client) monitorPservers(l Lister, pserverNum int) { func (c *Client) monitorPservers(l Lister, pserverNum int) {
knownServers := make([]Server, pserverNum) lastServers := make([]Server, pserverNum)
ticker := time.NewTicker(10 * time.Second) ticker := time.NewTicker(10 * time.Second)
monitor := func() { monitor := func() {
curServers := make([]Server, pserverNum) curServers := make([]Server, pserverNum)
...@@ -56,8 +56,17 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -56,8 +56,17 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
curServers[l.Index] = l curServers[l.Index] = l
} }
for i := range knownServers { for i := range lastServers {
if knownServers[i].Addr != curServers[i].Addr { if lastServers[i].Addr != curServers[i].Addr {
if curServers[i].Addr == "" {
err := c.pservers[i].Close()
if err != nil {
log.Println(err)
}
continue
}
err := c.pservers[i].Connect(curServers[i].Addr) err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
...@@ -65,12 +74,12 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -65,12 +74,12 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// connect to addr failed, set // connect to addr failed, set
// to last known addr in order // to last known addr in order
// to retry next time. // to retry next time.
curServers[i].Addr = knownServers[i].Addr curServers[i].Addr = lastServers[i].Addr
} }
} }
} }
knownServers = curServers lastServers = curServers
} }
monitor() monitor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册