提交 41af738a 编写于 作者: H Helin Wang

fix according to comments

上级 54e8263c
......@@ -62,6 +62,7 @@ func (c *Conn) Connect(addr string) error {
c.waitConn = nil
}
} else {
client.Close()
return errors.New("client already set from a concurrent goroutine")
}
......
......@@ -11,10 +11,6 @@ import (
"github.com/PaddlePaddle/recordio"
)
const (
targetTaskCount = 300
)
// Service is the master server service.
type Service struct {
chunksPerTask int
......@@ -23,7 +19,7 @@ type Service struct {
ready chan struct{}
mu sync.Mutex
initBegan bool
initDone bool
taskQueues taskQueues
}
......@@ -104,54 +100,35 @@ func (s *Service) snapshot() error {
return nil
}
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be called multiple times, but only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
s.mu.Lock()
defer s.mu.Unlock()
if s.initBegan {
// SetDataset already called. All trainers will call
// SetDataset, but we only handle the first one. Treat
// other calls as successful but do nothing.
return nil
}
s.initBegan = true
func getChunks(globPaths []string) ([]Chunk, error) {
var chunks []Chunk
var paths []string
for _, s := range globPaths {
match, err := filepath.Glob(s)
if err != nil {
panic(err)
return nil, err
}
paths = append(paths, match...)
}
if len(paths) == 0 {
return errors.New("no valid datset specified")
return nil, errors.New("no valid datset specified")
}
for _, path := range paths {
f, err := os.Open(path)
if err != nil {
panic(err)
return nil, err
}
index, err := recordio.LoadIndex(f)
if err != nil {
return err
return nil, err
}
err = f.Close()
if err != nil {
return err
return nil, err
}
count := index.NumChunks()
......@@ -164,14 +141,41 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error {
}
}
return chunks, nil
}
// SetDataset sets the dataset to dispatch for the master server.
//
// SetDataset can be called multiple times, but only the first call
// that completes successfully is honored: initDone is set only after
// the chunks are loaded, partitioned, and snapshotted, so a call that
// fails leaves the service uninitialized and a later call tries again.
// Calls made after initialization return nil without doing anything.
//
// globPaths are glob patterns passed to getChunks; an empty list is
// rejected with an error. dummy is not read by this method.
//
// It returns an error if globPaths is empty, or if getChunks or
// snapshot fails.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
if len(globPaths) == 0 {
return errors.New("no dataset specified")
}
// Hold the lock for the whole initialization so that concurrent
// callers observe a consistent initDone value.
s.mu.Lock()
defer s.mu.Unlock()
if s.initDone {
// Already initialized. All trainers will call
// SetDataset, but we only handle the first one. Treat
// other calls as successful but do nothing.
return nil
}
chunks, err := getChunks(globPaths)
if err != nil {
return err
}
// Group the chunks into tasks of s.chunksPerTask chunks and queue
// them as pending work.
s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
// NOTE(review): the next two lines look like the removed and added
// sides of one diff line; only one of them can be valid here —
// confirm against the repository.
err := s.snapshot()
err = s.snapshot()
if err != nil {
return err
}
// Closing ready signals that initialization has finished; initDone
// is flipped last so a failure above leaves the service retryable.
close(s.ready)
s.initDone = true
return nil
}
......@@ -193,7 +197,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// TODO(helin): client needs to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// erros like io.EOF. Because interface don't
// errors like io.EOF, because interface values don't
// have the same dynamic value when in a different
// process.
return errors.New("no more available task")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册