提交 41af738a 编写于 作者: H Helin Wang

fix according to comments

上级 54e8263c
...@@ -62,6 +62,7 @@ func (c *Conn) Connect(addr string) error { ...@@ -62,6 +62,7 @@ func (c *Conn) Connect(addr string) error {
c.waitConn = nil c.waitConn = nil
} }
} else { } else {
client.Close()
return errors.New("client already set from a concurrent goroutine") return errors.New("client already set from a concurrent goroutine")
} }
......
...@@ -11,10 +11,6 @@ import ( ...@@ -11,10 +11,6 @@ import (
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
) )
const (
targetTaskCount = 300
)
// Service is the master server service. // Service is the master server service.
type Service struct { type Service struct {
chunksPerTask int chunksPerTask int
...@@ -23,7 +19,7 @@ type Service struct { ...@@ -23,7 +19,7 @@ type Service struct {
ready chan struct{} ready chan struct{}
mu sync.Mutex mu sync.Mutex
initBegan bool initDone bool
taskQueues taskQueues taskQueues taskQueues
} }
...@@ -104,54 +100,35 @@ func (s *Service) snapshot() error { ...@@ -104,54 +100,35 @@ func (s *Service) snapshot() error {
return nil return nil
} }
// SetDataset sets dataset to dispatch for the master server. func getChunks(globPaths []string) ([]Chunk, error) {
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
s.mu.Lock()
defer s.mu.Unlock()
if s.initBegan {
// SetDataset already called. All trainer will call
// SetDataset, but we only handle the first one. Treat
// other calls as successful but do nothing.
return nil
}
s.initBegan = true
var chunks []Chunk var chunks []Chunk
var paths []string var paths []string
for _, s := range globPaths { for _, s := range globPaths {
match, err := filepath.Glob(s) match, err := filepath.Glob(s)
if err != nil { if err != nil {
panic(err) return nil, err
} }
paths = append(paths, match...) paths = append(paths, match...)
} }
if len(paths) == 0 { if len(paths) == 0 {
return errors.New("no valid datset specified") return nil, errors.New("no valid datset specified")
} }
for _, path := range paths { for _, path := range paths {
f, err := os.Open(path) f, err := os.Open(path)
if err != nil { if err != nil {
panic(err) return nil, err
} }
index, err := recordio.LoadIndex(f) index, err := recordio.LoadIndex(f)
if err != nil { if err != nil {
return err return nil, err
} }
err = f.Close() err = f.Close()
if err != nil { if err != nil {
return err return nil, err
} }
count := index.NumChunks() count := index.NumChunks()
...@@ -164,14 +141,41 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { ...@@ -164,14 +141,41 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
} }
} }
return chunks, nil
}
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
s.mu.Lock()
defer s.mu.Unlock()
if s.initDone {
// Already initialized. All trainer will call
// SetDataset, but we only handle the first one. Treat
// other calls as successful but do nothing.
return nil
}
chunks, err := getChunks(globPaths)
if err != nil {
return err
}
s.taskQueues.Todo = partition(chunks, s.chunksPerTask) s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
err := s.snapshot() err = s.snapshot()
if err != nil { if err != nil {
return err return err
} }
close(s.ready) close(s.ready)
s.initDone = true
return nil return nil
} }
...@@ -193,7 +197,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { ...@@ -193,7 +197,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// TODO(helin): client need to retry in this // TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't // error case. Gotcha: RPC client can't
// compare returned error with predefined // compare returned error with predefined
// erros like io.EOF. Because interface don't // errors like io.EOF. Because interface don't
// have same dynamic value when in different // have same dynamic value when in different
// process. // process.
return errors.New("no more available task") return errors.New("no more available task")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册