提交 fa5c3f1f 编写于 作者: H Helin Wang

implement master client, Go part

上级 91f82aba
package main
/*
typedef int paddle_master_client;
*/
import "C"
import (
"log"
"sync"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
)
var mu sync.Mutex
var handleMap = make(map[C.paddle_master_client]*master.Client)
var curHandle C.paddle_master_client
// add registers c in the package-level handle table and returns the
// handle under which it was stored. Handles are handed out in
// increasing order, starting from the zero value of curHandle.
func add(c *master.Client) C.paddle_master_client {
	mu.Lock()
	defer mu.Unlock()

	h := curHandle
	handleMap[h] = c
	// Advance the counter so the next registration gets a fresh handle.
	curHandle = h + 1
	return h
}
// get looks up the client stored under the given handle. It returns
// nil when the handle is not present in the table.
func get(client C.paddle_master_client) *master.Client {
	mu.Lock()
	c := handleMap[client]
	mu.Unlock()
	return c
}
// remove drops the given handle from the table and returns the client
// that was stored under it, or nil when the handle was not present.
func remove(client C.paddle_master_client) *master.Client {
	mu.Lock()
	defer mu.Unlock()

	c, found := handleMap[client]
	if found {
		delete(handleMap, client)
	}
	return c
}
// addresser is a string holding an address; its Address method returns
// that string unchanged.
type addresser string

// Address returns the underlying address string.
func (addr addresser) Address() string {
	s := string(addr)
	return s
}
// paddle_new_master_client creates a master client for the master
// server at addr and returns an integer handle that identifies it in
// later calls.
//
// addr is a C string holding the master address. buf_size is forwarded
// to master.NewClient as the record buffer size.
//
// The cgo directive below must start with "//export "; without that
// keyword the function is not made callable from C.
//
//export paddle_new_master_client
func paddle_new_master_client(addr *C.char, buf_size C.int) C.paddle_master_client {
	a := C.GoString(addr)
	c := master.NewClient(addresser(a), int(buf_size))
	return add(c)
}
// paddle_new_etcd_master_client is a placeholder for an etcd-backed
// master client. It is not implemented yet: every call panics with
// "not implemented." and etcd_addr is never read.
//
//export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcd_addr *C.char) C.paddle_master_client {
// TODO(helin): fault tolerant master client using etcd.
panic("not implemented.")
}
// paddle_set_dataset passes a list of dataset paths to the master
// server through the client identified by the given handle.
//
// path points to the first element of a C array of size C strings.
// Each string is copied into Go memory before the RPC is issued.
//
// It returns 0 on success and -1 when the handle is unknown or when
// SetDataset reports an error (the error is logged).
//
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
	c := get(client)
	if c == nil {
		// An unknown handle would otherwise cause a nil dereference
		// inside SetDataset and crash the calling C program.
		log.Println("invalid master client handle:", client)
		return -1
	}

	var paths []string
	for i := 0; i < int(size); i++ {
		// Step to the i-th element of the C array: the offset is the
		// element index times the size of one *C.char, computed in a
		// single expression as required by the unsafe.Pointer rules.
		ptr := (**C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(path)) + uintptr(i)*unsafe.Sizeof(*path)))
		paths = append(paths, C.GoString(*ptr))
	}

	if err := c.SetDataset(paths); err != nil {
		log.Println(err)
		return -1
	}
	return 0
}
// main is intentionally empty. NOTE(review): presumably it exists only
// because this cgo package, which exports functions to C, has to be
// package main and therefore needs a main function — confirm against
// the build configuration.
func main() {}
......@@ -2,9 +2,11 @@ package master
import (
"log"
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
// Addresser provide the address of the master server.
......@@ -15,16 +17,51 @@ type Addresser interface {
// Client is the client of the master server.
type Client struct {
// conn is the connection used to issue RPC calls to the master server.
conn *connection.Conn
// ch carries the records that getRecords reads from task chunks;
// NextRecord receives from it.
ch chan []byte
}
// NewClient creates a new Client.
func NewClient(addr Addresser) *Client {
//
// bufSize is the record buffer size. NextRecord will read from the
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
	client := &Client{
		conn: connection.New(),
		// Records are buffered in a channel of capacity bufSize.
		ch: make(chan []byte, bufSize),
	}

	// Two background goroutines are started for the new client: one
	// follows the master address, the other fills the record channel.
	go client.monitorMaster(addr)
	go client.getRecords()
	return client
}
// getRecords loops forever: it fetches a task from the master, reads
// every record of every chunk in that task into c.ch, and then reports
// the task as finished. Errors are logged and the loop moves on, so
// the function never returns.
func (c *Client) getRecords() {
	for {
		t, err := c.getTask()
		if err != nil {
			// NOTE(review): this retries immediately with no delay
			// between failed getTask calls; consider adding a backoff.
			log.Println(err)
			continue
		}

		for _, chunk := range t.Chunks {
			f, err := os.Open(chunk.Path)
			if err != nil {
				log.Println(err)
				continue
			}

			s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
			for s.Scan() {
				// Blocks while the channel buffer is full.
				c.ch <- s.Record()
			}
			// Scan returning false may mean a read error rather than
			// the end of the chunk; surface it instead of dropping it.
			if err := s.Err(); err != nil {
				log.Println(err, chunk.Path)
			}

			if err := f.Close(); err != nil {
				log.Println(err)
			}
		}

		// The task is reported as finished once all of its chunks have
		// been handed to the channel.
		if err := c.taskFinished(t.ID); err != nil {
			log.Println(err)
		}
	}
}
func (c *Client) monitorMaster(addr Addresser) {
lastMaster := ""
monitor := func() {
......@@ -69,14 +106,22 @@ func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
}
// GetTask gets a new task from the master server.
func (c *Client) GetTask() (Task, error) {
// getTask asks the master server for a new task via the
// "Service.GetTask" RPC and returns the task together with the error
// reported by the call.
func (c *Client) getTask() (Task, error) {
	var task Task
	if err := c.conn.Call("Service.GetTask", 0, &task); err != nil {
		return task, err
	}
	return task, nil
}
// TaskFinished tells the master server a task is finished.
func (c *Client) TaskFinished(taskID int) error {
// taskFinished notifies the master server, via the
// "Service.TaskFinished" RPC, that the task with the given ID is done.
func (c *Client) taskFinished(taskID int) error {
	err := c.conn.Call("Service.TaskFinished", taskID, nil)
	return err
}
// NextRecord returns the next record of the dataset.
//
// The call blocks until a record can be received from the client's
// record channel. It is safe to call from multiple goroutines.
func (c *Client) NextRecord() []byte {
	rec := <-c.ch
	return rec
}
package master
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
const (
totalTask = 20
chunkPerTask = 10
)
// init sets the logrus log level to ErrorLevel for this test package.
func init() {
log.SetLevel(log.ErrorLevel)
}
// TestAddresser is a string holding an address; its Address method
// returns that string unchanged.
type TestAddresser string

// Address returns the underlying address string.
func (addr TestAddresser) Address() string {
	s := string(addr)
	return s
}
// TestGetFinishTask starts a master Service served over HTTP RPC on a
// local listener, writes a dataset file, registers it through a Client
// and then checks the getTask / taskFinished protocol for 10 passes.
func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0"
// Listen on port 0 so that a free port is picked automatically.
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
// Extract the chosen port: it is the last ":"-separated field of the
// listener address.
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
// Serve the master Service on the listener in the background.
go func(l net.Listener) {
// NOTE(review): the meaning of the three NewService arguments is not
// visible here — confirm against the Service definition.
s := NewService(chunkPerTask, time.Second, 1)
server := rpc.NewServer()
err := server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
// Create the dataset file. A fresh writer is opened and closed
// totalTask*chunkPerTask times, one record written each time.
f, err := os.Create(path)
if err != nil {
panic(err)
}
for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1)
w.Write(nil)
// call Close to force RecordIO writing a chunk.
w.Close()
}
f.Close()
// Build the client by hand instead of calling NewClient, so that the
// getRecords goroutine is not started and tasks are only taken by the
// explicit getTask calls below.
c := &Client{}
c.conn = connection.New()
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
c.SetDataset([]string{path})
// checkOnePass runs one pass over the dataset; i is the pass number
// and is only used in failure messages.
checkOnePass := func(i int) {
var tasks []Task
// Take totalTask tasks; every call is expected to succeed.
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatal(err, " pass:", i)
}
tasks = append(tasks, task)
}
// With totalTask tasks outstanding, one more request must fail.
_, err = c.getTask()
if err == nil {
t.Fatal("Should get error. Pass:", i)
}
// Finish one task; after that a further getTask must succeed.
err = c.taskFinished(tasks[0].ID)
if err != nil {
t.Fatal(err, "pass:", i)
}
tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
}
tasks = append(tasks, task)
// Finish every task still held so the next pass starts clean.
for _, task := range tasks {
err = c.taskFinished(task.ID)
if err != nil {
t.Fatal(err, " pass:", i)
}
}
}
for i := 0; i < 10; i++ {
checkOnePass(i)
}
}
......@@ -11,21 +11,15 @@ import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
const (
totalTask = 20
chunkPerTask = 10
)
var port int
func init() {
log.SetLevel(log.ErrorLevel)
func TestNextRecord(t *testing.T) {
const (
path = "/tmp/master_client_TestFull"
total = 50
)
l, err := net.Listen("tcp", ":0")
if err != nil {
......@@ -37,10 +31,9 @@ func init() {
if err != nil {
panic(err)
}
port = p
go func(l net.Listener) {
s := master.NewService(chunkPerTask, time.Second, 1)
s := master.NewService(10, time.Second, 1)
server := rpc.NewServer()
err := server.Register(s)
if err != nil {
......@@ -54,67 +47,33 @@ func init() {
panic(err)
}
}(l)
}
// addresser is a string holding an address; its Address method returns
// that string unchanged.
type addresser string

// Address returns the underlying address string.
func (addr addresser) Address() string {
	s := string(addr)
	return s
}
func TestClientFull(t *testing.T) {
const p = "/tmp/master_client_test_0"
f, err := os.Create(p)
f, err := os.Create(path)
if err != nil {
panic(err)
}
for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1)
w.Write(nil)
// call Close to force RecordIO writing a chunk.
w.Close()
for i := 0; i < total; i++ {
w.Write([]byte{byte(i)})
}
w.Close()
f.Close()
c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
c.SetDataset([]string{p})
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10)
c.SetDataset([]string{path})
checkOnePass := func(i int) {
var tasks []master.Task
for i := 0; i < totalTask; i++ {
task, err := c.GetTask()
if err != nil {
t.Fatal(i, err)
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r := c.NextRecord()
if len(r) != 1 {
t.Fatal("Length should be 1.", r)
}
tasks = append(tasks, task)
if received[r[0]] {
t.Fatal("Received duplicate.", received, r)
}
_, err = c.GetTask()
if err == nil {
t.Fatal(i, "should get error.")
received[r[0]] = true
}
err = c.TaskFinished(tasks[0].ID)
if err != nil {
t.Fatal(err)
}
tasks = tasks[1:]
task, err := c.GetTask()
if err != nil {
t.Fatal(err)
}
tasks = append(tasks, task)
for _, task := range tasks {
err = c.TaskFinished(task.ID)
if err != nil {
t.Fatal(i, err)
}
}
}
for i := 0; i < 10; i++ {
checkOnePass(i)
}
}
......@@ -217,6 +217,16 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
}
}
// logFields returns the current length of each task queue (todo,
// pending, done, failed) as structured log fields.
//
// It must be called with the lock held.
func (s *Service) logFields() log.Fields {
	fields := make(log.Fields, 4)
	fields["todoLen"] = len(s.taskQueues.Todo)
	fields["pendingLen"] = len(s.taskQueues.Pending)
	fields["doneLen"] = len(s.taskQueues.Done)
	fields["failedLen"] = len(s.taskQueues.Failed)
	return fields
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
select {
......@@ -230,7 +240,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 {
err := errors.New("all task failed")
log.Warningln(err)
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
}
......@@ -243,12 +253,12 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.Warningln(err)
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
}
t := s.taskQueues.Todo[0]
......@@ -261,7 +271,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
*task = t.Task
log.Infof("Task #%d dispatched\n", task.ID)
log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID)
time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
return nil
......@@ -276,12 +286,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.mu.Lock()
defer s.mu.Unlock()
log.Infof("Task %d finished\n", taskID)
t, ok := s.taskQueues.Pending[taskID]
if !ok {
err := errors.New("pending task not found")
log.Warningln(err)
log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
return err
}
......@@ -290,8 +298,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.Infoln("No more todo and pending task, start a new pass.")
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册