master.go 1.4 KB
Newer Older
H
Helin Wang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
package main

import (
	"flag"
	"net"
	"net/http"
	"net/rpc"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/PaddlePaddle/Paddle/paddle/go/master"
	"github.com/wangkuiyi/recordio"
)

const (
	// taskTimeoutDur is handed to master.NewService as the task timeout;
	// presumably how long a dispatched task may stay outstanding before the
	// service treats it as timed out — confirm against the master package.
	taskTimeoutDur time.Duration = 20 * time.Minute

	// taskTimeoutMax is handed to master.NewService alongside taskTimeoutDur;
	// presumably the number of timeouts a task may accumulate before it is
	// given up on — confirm against the master package.
	taskTimeoutMax = 3
)

// main starts the master server. It splits every RecordIO file named by the
// comma-separated -d flag into chunks, registers a master.Service over those
// chunks with net/rpc, and serves the RPC endpoint over HTTP on the port
// given by -p.
//
// Any setup failure (missing dataset, unreadable file or index, RPC
// registration error, listen failure) panics; http.Serve only returns on
// error, which also panics.
func main() {
	port := flag.Int("p", 0, "port of the master server")
	dataset := flag.String("d", "", "dataset: comma separated path to RecordIO files")
	faultTolerant := flag.Bool("fault-tolerance", false, "enable fault tolerance (requires etcd).")
	flag.Parse()

	// Drop empty entries so a stray or trailing comma ("a,b,") does not
	// become an attempt to open the empty path.
	var paths []string
	for _, p := range strings.Split(*dataset, ",") {
		if p != "" {
			paths = append(paths, p)
		}
	}
	if len(paths) == 0 {
		panic("no dataset specified.")
	}

	if *faultTolerant {
		panic("fault tolerance not implemented.")
	}

	var chunks []master.Chunk
	// idx numbers chunks consecutively across all dataset files.
	// NOTE(review): assumes master.Chunk.Idx is a dataset-wide chunk number
	// rather than a per-file one — confirm against the master package.
	idx := 0
	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
			panic(err)
		}

		index, err := recordio.LoadIndex(f)
		// Only the index is needed from the file, so release it before
		// inspecting err; this closes it on both the success and error paths.
		// os.Open opens read-only, so a Close error cannot lose data.
		f.Close()
		if err != nil {
			panic(err)
		}

		count := index.NumChunks()
		for i := 0; i < count; i++ {
			chunks = append(chunks, master.Chunk{
				Idx:   idx,
				Path:  path,
				Index: *index.ChunkIndex(i),
			})
			idx++
		}
	}

	s := master.NewService(chunks, taskTimeoutDur, taskTimeoutMax)
	if err := rpc.Register(s); err != nil {
		panic(err)
	}

	// HandleHTTP mounts the RPC server on http.DefaultServeMux, which is the
	// handler http.Serve uses when passed a nil handler below.
	rpc.HandleHTTP()

	// With the default -p 0 the OS assigns an ephemeral port that is not
	// reported anywhere, so callers should pass -p explicitly.
	l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
	if err != nil {
		panic(err)
	}

	err = http.Serve(l, nil)
	if err != nil {
		panic(err)
	}
}