diff --git a/go/connection/conn.go b/go/connection/conn.go index 0bab2def1d9b4d05484fa2eb752ecf09b02aaf30..ea6bf972f641be9b58c4e8a8260d6c23de8d1163 100644 --- a/go/connection/conn.go +++ b/go/connection/conn.go @@ -62,6 +62,7 @@ func (c *Conn) Connect(addr string) error { c.waitConn = nil } } else { + client.Close() return errors.New("client already set from a concurrent goroutine") } diff --git a/go/master/service.go b/go/master/service.go index c80037a3b35927dd4ab7e2085cb6d7a6d116ec16..30859d92963f9b2bb67cea971c771c826c27c3fd 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -11,10 +11,6 @@ import ( "github.com/PaddlePaddle/recordio" ) -const ( - targetTaskCount = 300 -) - // Service is the master server service. type Service struct { chunksPerTask int @@ -23,7 +19,7 @@ type Service struct { ready chan struct{} mu sync.Mutex - initBegan bool + initDone bool taskQueues taskQueues } @@ -104,54 +100,35 @@ func (s *Service) snapshot() error { return nil } -// SetDataset sets dataset to dispatch for the master server. -// -// SetDataset can be call multiple times. But only the first call will -// be honored. -func (s *Service) SetDataset(globPaths []string, dummy *int) error { - if len(globPaths) == 0 { - return errors.New("no dataset specified") - } - - s.mu.Lock() - defer s.mu.Unlock() - if s.initBegan { - // SetDataset already called. All trainer will call - // SetDataset, but we only handle the first one. Treat - // other calls as successful but do nothing. - return nil - } - - s.initBegan = true - +func getChunks(globPaths []string) ([]Chunk, error) { var chunks []Chunk var paths []string for _, s := range globPaths { match, err := filepath.Glob(s) if err != nil { - panic(err) + return nil, err } paths = append(paths, match...) } if len(paths) == 0 { - return errors.New("no valid datset specified") + return nil, errors.New("no valid datset specified") } for _, path := range paths { f, err := os.Open(path) if err != nil { - panic(err) + return nil, err } index, err := recordio.LoadIndex(f) if err != nil { - return err + return nil, err } err = f.Close() if err != nil { - return err + return nil, err } count := index.NumChunks() @@ -164,14 +141,41 @@ func (s *Service) SetDataset(globPaths []string, dummy *int) error { } } + return chunks, nil +} + +// SetDataset sets dataset to dispatch for the master server. +// +// SetDataset can be call multiple times. But only the first call will +// be honored. +func (s *Service) SetDataset(globPaths []string, dummy *int) error { + if len(globPaths) == 0 { + return errors.New("no dataset specified") + } + + s.mu.Lock() + defer s.mu.Unlock() + if s.initDone { + // Already initialized. All trainer will call + // SetDataset, but we only handle the first one. Treat + // other calls as successful but do nothing. + return nil + } + + chunks, err := getChunks(globPaths) + if err != nil { + return err + } + s.taskQueues.Todo = partition(chunks, s.chunksPerTask) - err := s.snapshot() + err = s.snapshot() if err != nil { return err } close(s.ready) + s.initDone = true return nil } @@ -193,7 +197,7 @@ func (s *Service) GetTask(dummy int, task *Task) error { // TODO(helin): client need to retry in this // error case. Gotcha: RPC client can't // compare returned error with predefined - // erros like io.EOF. Because interface don't + // errors like io.EOF. Because interface don't // have same dynamic value when in different // process. return errors.New("no more available task")