package main

import (
	"fmt"
	"net"
	"net/http"
	"net/rpc"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"time"

	"github.com/namsral/flag"

	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/PaddlePaddle/Paddle/go/recordio"
)

func main() {
H
Helin Wang 已提交
21 22 23 24 25 26
	port := flag.Int("port", 8080, "port of the master server.")
	dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")
	faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
	taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
	taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
	chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
H
Helin Wang 已提交
27 28 29 30 31 32
	flag.Parse()

	if *dataset == "" {
		panic("no dataset specified.")
	}

H
Helin Wang 已提交
33 34
	if *faultTolerance {
		panic("fault tolernance not implemented.")
H
Helin Wang 已提交
35 36 37
	}

	var chunks []master.Chunk
H
Helin Wang 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
	var paths []string
	ss := strings.Split(*dataset, ",")
	fmt.Println(ss)
	for _, s := range ss {
		match, err := filepath.Glob(s)
		if err != nil {
			panic(err)
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
		panic("no valid datset specified.")
	}

H
Helin Wang 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
	idx := 0
	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
			panic(err)
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
			panic(err)
		}
		f.Close()

		count := index.NumChunks()
		for i := 0; i < count; i++ {
			chunk := master.Chunk{
				Idx:   idx,
				Path:  path,
				Index: *index.ChunkIndex(i),
			}
			chunks = append(chunks, chunk)
		}
	}

H
Helin Wang 已提交
77
	s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
H
Helin Wang 已提交
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
	err := rpc.Register(s)
	if err != nil {
		panic(err)
	}

	rpc.HandleHTTP()
	l, err := net.Listen("tcp", ":"+strconv.Itoa(*port))
	if err != nil {
		panic(err)
	}

	err = http.Serve(l, nil)
	if err != nil {
		panic(err)
	}
}