diff --git a/paddle/go/cmd/master/master.go b/paddle/go/cmd/master/master.go index 16052fd75c78d3f42704af3cd800ca0fd5efd7b2..ef1f87c2dd53b701810c82ae90eaf3f94ea15e47 100644 --- a/paddle/go/cmd/master/master.go +++ b/paddle/go/cmd/master/master.go @@ -1,40 +1,55 @@ package main import ( - "flag" + "fmt" "net" "net/http" "net/rpc" "os" + "path/filepath" "strconv" "strings" "time" + "github.com/namsral/flag" + "github.com/PaddlePaddle/Paddle/paddle/go/master" "github.com/PaddlePaddle/Paddle/paddle/go/recordio" ) -const ( - taskTimeoutDur = 20 * time.Minute - taskTimeoutMax = 3 -) - func main() { - port := flag.Int("p", 0, "port of the master server") - dataset := flag.String("d", "", "dataset: comma separated path to RecordIO files") - faultTolerant := flag.Bool("fault-tolerance", false, "enable fault tolerance (requires etcd).") + port := flag.Int("port", 8080, "port of the master server.") + dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.") + faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).") + taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.") + taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.") + chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.") flag.Parse() if *dataset == "" { panic("no dataset specified.") } - if *faultTolerant { - panic("fault tolernat not implemented.") + if *faultTolerance { + panic("fault tolernance not implemented.") } var chunks []master.Chunk - paths := strings.Split(*dataset, ",") + var paths []string + ss := strings.Split(*dataset, ",") + fmt.Println(ss) + for _, s := range ss { + match, err := filepath.Glob(s) + if err != nil { + panic(err) + } + paths = append(paths, match...) + } + + if len(paths) == 0 { + panic("no valid datset specified.") + } + idx := 0 for _, path := range paths { f, err := os.Open(path) @@ -59,7 +74,7 @@ func main() { } } - s := master.NewService(chunks, taskTimeoutDur, taskTimeoutMax) + s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax) err := rpc.Register(s) if err != nil { panic(err) diff --git a/paddle/go/cmd/pserver/pserver.go b/paddle/go/cmd/pserver/pserver.go index 41417875fb98aca3f2181841b28f7b220e948618..bd4bfc7028302df1c3e6ecd3cc9ebb11b158df01 100644 --- a/paddle/go/cmd/pserver/pserver.go +++ b/paddle/go/cmd/pserver/pserver.go @@ -1,17 +1,18 @@ package main import ( - "flag" "net" "net/http" "net/rpc" "strconv" + "github.com/namsral/flag" + "github.com/PaddlePaddle/Paddle/paddle/go/pserver" ) func main() { - port := flag.Int("p", 0, "port of the pserver") + port := flag.Int("port", 0, "port of the pserver") flag.Parse() s := pserver.NewService() diff --git a/paddle/go/master/service.go b/paddle/go/master/service.go index cf15f28cc7403cb01d876f8105dc29412e3fa231..75266482870c448fcde7359640bc4773c200fecb 100644 --- a/paddle/go/master/service.go +++ b/paddle/go/master/service.go @@ -34,17 +34,16 @@ func Recover() (*Service, error) { return nil, nil } -func partition(chunks []Chunk, targetTaskCount int) []taskEntry { +func partition(chunks []Chunk, chunksPerTask int) []taskEntry { id := 0 - chunkPerTask := len(chunks) / targetTaskCount - if chunkPerTask <= 0 { - chunkPerTask = 1 + if chunksPerTask <= 0 { + chunksPerTask = 1 } var result []taskEntry var cur taskEntry for i, c := range chunks { - if i%chunkPerTask == 0 && len(cur.Task.Chunks) > 0 { + if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { cur.Task.ID = id id++ result = append(result, cur) @@ -64,13 +63,13 @@ func partition(chunks []Chunk, targetTaskCount int) []taskEntry { } // NewService creates a new service. -func NewService(chunks []Chunk, timeoutDur time.Duration, timeoutMax int) *Service { +func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service { s := &Service{} s.timeoutDur = timeoutDur s.timeoutMax = timeoutMax s.taskQueues = taskQueues{} s.taskQueues.Pending = make(map[int]taskEntry) - s.taskQueues.Todo = partition(chunks, targetTaskCount) + s.taskQueues.Todo = partition(chunks, chunksPerTask) return s } diff --git a/paddle/go/master/service_internal_test.go b/paddle/go/master/service_internal_test.go index 1e6197d2418084d3e4d505eda9423771c4369102..bc435b505c014ca13ed5fc16b33a21ebb089a3b7 100644 --- a/paddle/go/master/service_internal_test.go +++ b/paddle/go/master/service_internal_test.go @@ -4,18 +4,23 @@ import "testing" func TestPartitionCount(t *testing.T) { cs := make([]Chunk, 100) - ts := partition(cs, 20) + ts := partition(cs, 5) if len(ts) != 20 { t.Error(len(ts)) } cs = make([]Chunk, 101) - ts = partition(cs, 20) + ts = partition(cs, 5) if len(ts) != 21 { t.Error(len(ts)) } - ts = partition(cs, 200) + ts = partition(cs, 1) + if len(ts) != 101 { + t.Error(len(ts)) + } + + ts = partition(cs, 0) if len(ts) != 101 { t.Error(len(ts)) }