提交 9b11e17d 编写于 作者: H Helin Wang

fix according to comments

上级 6ce7c8bc
package main package main
import ( import (
"flag" "fmt"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"os" "os"
"path/filepath"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/paddle/go/master" "github.com/PaddlePaddle/Paddle/paddle/go/master"
"github.com/PaddlePaddle/Paddle/paddle/go/recordio" "github.com/PaddlePaddle/Paddle/paddle/go/recordio"
) )
const (
taskTimeoutDur = 20 * time.Minute
taskTimeoutMax = 3
)
func main() { func main() {
port := flag.Int("p", 0, "port of the master server") port := flag.Int("port", 8080, "port of the master server.")
dataset := flag.String("d", "", "dataset: comma separated path to RecordIO files") dataset := flag.String("training_dataset", "", "dataset: comma separated path to RecordIO paths, supports golb patterns.")
faultTolerant := flag.Bool("fault-tolerance", false, "enable fault tolerance (requires etcd).") faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timout duration.")
taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timtout count for each task before it being declared failed task.")
chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
flag.Parse() flag.Parse()
if *dataset == "" { if *dataset == "" {
panic("no dataset specified.") panic("no dataset specified.")
} }
if *faultTolerant { if *faultTolerance {
panic("fault tolernat not implemented.") panic("fault tolernance not implemented.")
} }
var chunks []master.Chunk var chunks []master.Chunk
paths := strings.Split(*dataset, ",") var paths []string
ss := strings.Split(*dataset, ",")
fmt.Println(ss)
for _, s := range ss {
match, err := filepath.Glob(s)
if err != nil {
panic(err)
}
paths = append(paths, match...)
}
if len(paths) == 0 {
panic("no valid datset specified.")
}
idx := 0 idx := 0
for _, path := range paths { for _, path := range paths {
f, err := os.Open(path) f, err := os.Open(path)
...@@ -59,7 +74,7 @@ func main() { ...@@ -59,7 +74,7 @@ func main() {
} }
} }
s := master.NewService(chunks, taskTimeoutDur, taskTimeoutMax) s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
err := rpc.Register(s) err := rpc.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
......
package main package main
import ( import (
"flag"
"net" "net"
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver" "github.com/PaddlePaddle/Paddle/paddle/go/pserver"
) )
func main() { func main() {
port := flag.Int("p", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
flag.Parse() flag.Parse()
s := pserver.NewService() s := pserver.NewService()
......
...@@ -34,17 +34,16 @@ func Recover() (*Service, error) { ...@@ -34,17 +34,16 @@ func Recover() (*Service, error) {
return nil, nil return nil, nil
} }
func partition(chunks []Chunk, targetTaskCount int) []taskEntry { func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0 id := 0
chunkPerTask := len(chunks) / targetTaskCount if chunksPerTask <= 0 {
if chunkPerTask <= 0 { chunksPerTask = 1
chunkPerTask = 1
} }
var result []taskEntry var result []taskEntry
var cur taskEntry var cur taskEntry
for i, c := range chunks { for i, c := range chunks {
if i%chunkPerTask == 0 && len(cur.Task.Chunks) > 0 { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.ID = id cur.Task.ID = id
id++ id++
result = append(result, cur) result = append(result, cur)
...@@ -64,13 +63,13 @@ func partition(chunks []Chunk, targetTaskCount int) []taskEntry { ...@@ -64,13 +63,13 @@ func partition(chunks []Chunk, targetTaskCount int) []taskEntry {
} }
// NewService creates a new service. // NewService creates a new service.
func NewService(chunks []Chunk, timeoutDur time.Duration, timeoutMax int) *Service { func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
s := &Service{} s := &Service{}
s.timeoutDur = timeoutDur s.timeoutDur = timeoutDur
s.timeoutMax = timeoutMax s.timeoutMax = timeoutMax
s.taskQueues = taskQueues{} s.taskQueues = taskQueues{}
s.taskQueues.Pending = make(map[int]taskEntry) s.taskQueues.Pending = make(map[int]taskEntry)
s.taskQueues.Todo = partition(chunks, targetTaskCount) s.taskQueues.Todo = partition(chunks, chunksPerTask)
return s return s
} }
......
...@@ -4,18 +4,23 @@ import "testing" ...@@ -4,18 +4,23 @@ import "testing"
func TestPartitionCount(t *testing.T) { func TestPartitionCount(t *testing.T) {
cs := make([]Chunk, 100) cs := make([]Chunk, 100)
ts := partition(cs, 20) ts := partition(cs, 5)
if len(ts) != 20 { if len(ts) != 20 {
t.Error(len(ts)) t.Error(len(ts))
} }
cs = make([]Chunk, 101) cs = make([]Chunk, 101)
ts = partition(cs, 20) ts = partition(cs, 5)
if len(ts) != 21 { if len(ts) != 21 {
t.Error(len(ts)) t.Error(len(ts))
} }
ts = partition(cs, 200) ts = partition(cs, 1)
if len(ts) != 101 {
t.Error(len(ts))
}
ts = partition(cs, 0)
if len(ts) != 101 { if len(ts) != 101 {
t.Error(len(ts)) t.Error(len(ts))
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册