提交 025e7f9c 编写于 作者: H Helin Wang

implement basic master server

上级 d7b5a136
package main
import (
"flag"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/paddle/go/master"
"github.com/wangkuiyi/recordio"
)
const (
taskTimeoutDur = 20 * time.Minute
taskTimeoutMax = 3
)
func main() {
port := flag.Int("p", 0, "port of the master server")
dataset := flag.String("d", "", "dataset: comma separated path to RecordIO files")
faultTolerant := flag.Bool("fault-tolerance", false, "enable fault tolerance (requires etcd).")
flag.Parse()
if *dataset == "" {
panic("no dataset specified.")
}
if *faultTolerant {
panic("fault tolernat not implemented.")
}
var chunks []master.Chunk
paths := strings.Split(*dataset, ",")
idx := 0
for _, path := range paths {
f, err := os.Open(path)
if err != nil {
panic(err)
}
index, err := recordio.LoadIndex(f)
if err != nil {
panic(err)
}
f.Close()
count := index.NumChunks()
for i := 0; i < count; i++ {
chunk := master.Chunk{
Idx: idx,
Path: path,
Index: *index.ChunkIndex(i),
}
chunks = append(chunks, chunk)
}
}
s := master.NewService(chunks, taskTimeoutDur, taskTimeoutMax)
err := rpc.Register(s)
if err != nil {
panic(err)
}
rpc.HandleHTTP()
l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
if err != nil {
panic(err)
}
err = http.Serve(l, nil)
if err != nil {
panic(err)
}
}
...@@ -64,14 +64,14 @@ func partition(chunks []Chunk, targetTaskCount int) []taskEntry { ...@@ -64,14 +64,14 @@ func partition(chunks []Chunk, targetTaskCount int) []taskEntry {
} }
// NewService creates a new service. // NewService creates a new service.
func NewService(chunks []Chunk, timeoutDur time.Duration, timeoutMax int) (*Service, error) { func NewService(chunks []Chunk, timeoutDur time.Duration, timeoutMax int) *Service {
s := &Service{} s := &Service{}
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax s.timeoutMax = timeoutMax
s.taskQueues = taskQueues{} s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry) s.taskQueues.Pending = make(map[int]taskEntry)
s.taskQueues.Todo = partition(chunks, targetTaskCount) s.taskQueues.Todo = partition(chunks, targetTaskCount)
return s, nil return s
} }
// Chunk is a chunk of data consisted of several data instances. // Chunk is a chunk of data consisted of several data instances.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册