提交 54e8263c 编写于 作者: H Helin Wang

implement master server client, remove unnecessary dummy variable

上级 72a73ab6
package main
import (
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
func main() {
port := flag.Int("port", 8080, "port of the master server.")
dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")
faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
flag.Parse()
if *dataset == "" {
panic("no dataset specified.")
}
if *faultTolerance {
panic("fault tolernance not implemented.")
}
var chunks []master.Chunk
var paths []string
ss := strings.Split(*dataset, ",")
fmt.Println(ss)
for _, s := range ss {
match, err := filepath.Glob(s)
if err != nil {
panic(err)
}
paths = append(paths, match...)
}
if len(paths) == 0 {
panic("no valid datset specified.")
}
for _, path := range paths {
f, err := os.Open(path)
if err != nil {
panic(err)
}
index, err := recordio.LoadIndex(f)
if err != nil {
panic(err)
}
f.Close()
count := index.NumChunks()
for i := 0; i < count; i++ {
chunk := master.Chunk{
Path: path,
Index: *index.ChunkIndex(i),
}
chunks = append(chunks, chunk)
}
}
s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
err := rpc.Register(s)
if err != nil {
panic(err)
......
......@@ -59,16 +59,22 @@ func (c *Client) monitorMaster(addr Addresser) {
}
}
// SetDataset asks the master server to dispatch the dataset described by
// globPaths.
//
// Every node may call SetDataset, but the master server only honors the
// first call. The returned error is whatever the "Service.SetDataset" RPC
// reports.
func (c *Client) SetDataset(globPaths []string) error {
	const method = "Service.SetDataset"
	// The RPC carries no reply payload, so no reply value is supplied.
	err := c.conn.Call(method, globPaths, nil)
	return err
}
// GetTask gets a new task from the master server.
//
// It issues a single "Service.GetTask" RPC. The argument is unused by the
// server, so a literal 0 is sent. The returned Task is only meaningful
// when the returned error is nil.
func (c *Client) GetTask() (Task, error) {
	var t Task
	err := c.conn.Call("Service.GetTask", 0, &t)
	return t, err
}
// TaskFinished tells the master server a task is finished.
//
// taskID identifies the task being reported. The RPC has no reply payload,
// so nil is passed as the reply. The returned error is whatever the
// "Service.TaskFinished" RPC reports.
func (c *Client) TaskFinished(taskID int) error {
	return c.conn.Call("Service.TaskFinished", taskID, nil)
}
......@@ -5,12 +5,14 @@ import (
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
const (
......@@ -34,8 +36,7 @@ func init() {
port = p
go func(l net.Listener) {
chunks := make([]master.Chunk, totalTask)
s := master.NewService(chunks, chunkPerTask, time.Second, 1)
s := master.NewService(chunkPerTask, time.Second, 1)
server := rpc.NewServer()
err := server.Register(s)
if err != nil {
......@@ -58,21 +59,47 @@ func (a addresser) Address() string {
}
func TestClientFull(t *testing.T) {
const p = "/tmp/master_client_test_0"
f, err := os.Create(p)
if err != nil {
panic(err)
}
for i := 0; i < totalTask*chunkPerTask; i++ {
w := recordio.NewWriter(f, -1, -1)
w.Write(nil)
// call Close to force RecordIO writing a chunk.
w.Close()
}
f.Close()
c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
c.SetDataset([]string{p})
for i := 0; i < 5*totalTask/chunkPerTask; i++ {
checkOnePass := func(i int) {
var tasks []master.Task
for i := 0; i < totalTask; i++ {
task, err := c.GetTask()
if err != nil {
panic(err)
t.Fatal(i, err)
}
tasks = append(tasks, task)
}
if len(task.Chunks) != chunkPerTask {
t.Fatal("wrong number of chunk per task", len(task.Chunks))
_, err = c.GetTask()
if err == nil {
t.Fatal(i, "should get error.")
}
for _, task := range tasks {
err = c.TaskFinished(task.ID)
if err != nil {
panic(err)
t.Fatal(i, err)
}
}
}
for i := 0; i < 10; i++ {
checkOnePass(i)
}
}
......@@ -3,6 +3,8 @@ package master
import (
"errors"
"log"
"os"
"path/filepath"
"sync"
"time"
......@@ -13,18 +15,15 @@ const (
targetTaskCount = 300
)
// Predefined errors returned by the Service methods.
var (
// ErrNoMoreTask is returned by GetTask when the Todo queue is empty.
ErrNoMoreTask = errors.New("no more task for current pass")
// ErrPendingTaskNotFound is returned by TaskFinished when the given task
// ID is not in the Pending queue.
ErrPendingTaskNotFound = errors.New("pending task not found")
)
// Service is the master server service.
//
// A Service is created with NewService and registered as an RPC receiver;
// its exported methods are invoked by clients as "Service.<Method>".
type Service struct {
// chunksPerTask is the number of chunks handed to partition for each task
// when the dataset is set.
chunksPerTask int
// timeoutDur and timeoutMax are stored by NewService. Presumably they are
// the per-task timeout and the number of timeouts tolerated before a task
// is declared failed; their use is not visible in this block.
timeoutDur time.Duration
timeoutMax int
// ready is closed by SetDataset once the dataset has been loaded.
// GetTask and TaskFinished block on it before doing any work.
ready chan struct{}
// mu is locked by SetDataset, GetTask and TaskFinished before they touch
// initBegan or taskQueues.
mu sync.Mutex
// initBegan records that SetDataset has already been handled, so later
// calls become no-ops.
initBegan bool
// taskQueues holds the Todo, Pending and Done task collections.
taskQueues taskQueues
}
......@@ -63,13 +62,14 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
}
// NewService creates a new service.
func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
s := &Service{}
s.chunksPerTask = chunksPerTask
s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax
s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry)
s.taskQueues.Todo = partition(chunks, chunksPerTask)
s.ready = make(chan struct{})
return s
}
......@@ -104,13 +104,102 @@ func (s *Service) snapshot() error {
return nil
}
// SetDataset sets the dataset to dispatch for the master server.
//
// globPaths is a list of glob patterns. Every matching file is opened and
// its RecordIO index is loaded. The chunks found are partitioned into
// tasks of chunksPerTask chunks, which become the Todo queue. dummy is
// unused; it exists only because the RPC method signature needs a reply
// parameter.
//
// SetDataset can be called multiple times (every trainer calls it), but
// only the first successful call is honored. Later calls return nil and do
// nothing. A call that fails returns the error and leaves the service
// uninitialized, so a later call can retry instead of the service staying
// not-ready forever.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
	if len(globPaths) == 0 {
		return errors.New("no dataset specified")
	}

	// The lock is held for the whole call, so concurrent callers wait here
	// and then observe the outcome of the first one.
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.initBegan {
		// SetDataset already handled. All trainers will call
		// SetDataset, but we only handle the first one. Treat
		// other calls as successful but do nothing.
		return nil
	}

	var paths []string
	for _, pattern := range globPaths {
		match, err := filepath.Glob(pattern)
		if err != nil {
			// A malformed pattern comes from the client; report it
			// instead of crashing the server.
			return err
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
		return errors.New("no valid datset specified")
	}

	var chunks []Chunk
	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
			return err
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
			// Release the handle; the load error is the one worth
			// reporting.
			_ = f.Close()
			return err
		}

		err = f.Close()
		if err != nil {
			return err
		}

		count := index.NumChunks()
		for i := 0; i < count; i++ {
			chunks = append(chunks, Chunk{
				Path:  path,
				Index: *index.ChunkIndex(i),
			})
		}
	}

	s.taskQueues.Todo = partition(chunks, s.chunksPerTask)

	if err := s.snapshot(); err != nil {
		return err
	}

	// Mark initialization as done only after everything succeeded, then
	// release the callers blocked on the ready channel.
	s.initBegan = true
	close(s.ready)
	return nil
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
if len(s.taskQueues.Todo) == 0 {
return ErrNoMoreTask
if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 {
return errors.New("all task failed")
}
// TODO(helin): the client needs to retry in this
// error case. Gotcha: an RPC client can't
// compare the returned error with predefined
// errors like io.EOF, because interface values
// don't have the same dynamic value across
// different processes.
return errors.New("no more available task")
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Todo = nil
}
t := s.taskQueues.Todo[0]
......@@ -163,12 +252,16 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
t, ok := s.taskQueues.Pending[taskID]
if !ok {
return ErrPendingTaskNotFound
return errors.New("pending task not found")
}
// task finished, reset timeout
......@@ -176,8 +269,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s.taskQueues.Done = append(s.taskQueues.Done, t)
delete(s.taskQueues.Pending, taskID)
if len(s.taskQueues.Todo) == 0 {
s.taskQueues.Todo = s.taskQueues.Done
if len(s.taskQueues.Pending) == 0 {
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil
}
......
......@@ -102,16 +102,14 @@ func (c *Client) BeginInitParams() bool {
// InitParam initializes the parameter on parameter servers.
//
// The target parameter server is chosen by c.partition from the parameter
// name, and paramWithConfigs is sent to it via the "Service.InitParam"
// RPC. The RPC has no reply payload, so nil is passed as the reply.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
}
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
func (c *Client) FinishInitParams() error {
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.FinishInitParams", dummy, &dummy)
err := p.Call("Service.FinishInitParams", 0, nil)
if err != nil {
return err
}
......@@ -125,8 +123,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {
var dummy int
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
errCh <- err
}(g)
}
......@@ -205,8 +202,7 @@ func (c *Client) Save(path string) error {
errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.Save", path, &dummy)
err := p.Call("Service.Save", path, nil)
errCh <- err
}
......
......@@ -15,8 +15,7 @@ func TestFull(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
var dummy int
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, nil)
if err != nil {
t.FailNow()
}
......@@ -25,12 +24,12 @@ func TestFull(t *testing.T) {
p1.Name = "param_b"
p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
p1.ElementType = pserver.Float32
err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy)
err = s.InitParam(pserver.ParameterWithConfig{p1, nil}, nil)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
}
......@@ -46,11 +45,11 @@ func TestFull(t *testing.T) {
}
g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrad(g1, &dummy)
err = s.SendGrad(g1, nil)
if err != nil {
t.FailNow()
}
err = s.SendGrad(g2, &dummy)
err = s.SendGrad(g2, nil)
if err != nil {
t.FailNow()
......@@ -74,23 +73,21 @@ func TestFull(t *testing.T) {
func TestMultipleInit(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.FinishInitParams(0, &dummy)
err := s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
if err.Error() != pserver.AlreadyInitialized {
err = s.FinishInitParams(0, nil)
if err != pserver.ErrAlreadyInitialized {
t.FailNow()
}
}
func TestUninitialized(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.SendGrad(pserver.Gradient{}, &dummy)
if err.Error() != pserver.Uninitialized {
err := s.SendGrad(pserver.Gradient{}, nil)
if err != pserver.ErrUninitialized {
t.FailNow()
}
}
......@@ -112,8 +109,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg.Add(1)
go func() {
var dummy int
err := s.Save("", &dummy)
err := s.Save("", nil)
if err != nil {
t.FailNow()
}
......@@ -134,13 +130,12 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
var dummy int
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, nil)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册