service.go 8.4 KB
Newer Older
1 2 3
package master

import (
4 5 6
	"bytes"
	"compress/gzip"
	"encoding/gob"
7
	"errors"
8 9
	"os"
	"path/filepath"
10 11 12
	"sync"
	"time"

H
Helin Wang 已提交
13 14
	log "github.com/sirupsen/logrus"

15
	"github.com/PaddlePaddle/recordio"
16 17
)

18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
// dialTimeout is a five second timeout. It is not referenced in this
// part of the file; presumably it bounds connection attempts made
// elsewhere in the package (TODO confirm against its users).
const dialTimeout = 5 * time.Second

// Store is the interface for saving and loading the master state.
type Store interface {
	// Save persists the given serialized state.
	Save(state []byte) error
	// Load returns the previously saved state. The service treats a
	// nil state with a nil error as "nothing has been saved yet".
	Load() (state []byte, err error)
}

// Chunk is a chunk of data consisting of several data instances.
//
// readChunks builds one Chunk per entry of a file's recordio index:
// Path is the file the index was loaded from and Index is that entry.
type Chunk struct {
	Path  string
	Index recordio.Index // chunk index
}

// Task is the basic unit of data instances assigned to trainers.
//
// ID is assigned sequentially by partition, starting at 0, and is the
// key under which a dispatched task is tracked in the pending map.
// Chunks holds the chunks that make up the task.
type Task struct {
	ID     int
	Chunks []Chunk
}

// taskEntry is a Task together with the bookkeeping the service keeps
// for it.
//
// Epoch is incremented every time GetTask dispatches the task; a
// timeout check only acts if the epoch it was scheduled for is still
// the current one. NumTimeout counts timeouts since the task last
// finished: it is incremented by the timeout check and reset to zero
// by TaskFinished.
type taskEntry struct {
	Epoch      int
	NumTimeout int
	Task       Task
}

// taskQueues is the complete scheduling state of the service. It is
// the value that snapshot gob-encodes and recover decodes, so its
// exported field names are part of the persisted format.
//
// Todo holds tasks waiting to be dispatched, Pending holds tasks that
// were dispatched and have neither finished nor timed out, Done holds
// finished tasks until they are moved back to Todo for another pass,
// and Failed holds tasks discarded after timing out more than the
// allowed number of times.
type taskQueues struct {
	Todo    []taskEntry
	Pending map[int]taskEntry // map from task ID to task entry
	Done    []taskEntry
	Failed  []Task
}

53 54
// Service is the master server service.
type Service struct {
55 56 57 58
	chunksPerTask int
	timeoutDur    time.Duration
	timeoutMax    int
	ready         chan struct{}
59
	store         Store
60 61

	mu         sync.Mutex
H
Helin Wang 已提交
62
	initDone   bool
63 64 65
	taskQueues taskQueues
}

H
Helin Wang 已提交
66
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
67
	id := 0
H
Helin Wang 已提交
68 69
	if chunksPerTask <= 0 {
		chunksPerTask = 1
70 71 72 73 74
	}

	var result []taskEntry
	var cur taskEntry
	for i, c := range chunks {
H
Helin Wang 已提交
75
		if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
			cur.Task.ID = id
			id++
			result = append(result, cur)
			cur.Task.Chunks = nil
		}

		cur.Task.Chunks = append(cur.Task.Chunks, c)
	}

	if len(cur.Task.Chunks) > 0 {
		cur.Task.ID = id
		result = append(result, cur)
	}

	return result
}

// NewService creates a new service.
94
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) {
95
	s := &Service{}
96
	s.chunksPerTask = chunksPerTask
97 98 99 100
	s.timeoutDur = timeoutDur
	s.timeoutMax = timeoutMax
	s.taskQueues = taskQueues{}
	s.taskQueues.Pending = make(map[int]taskEntry)
101
	s.ready = make(chan struct{})
102 103 104 105 106
	s.store = store
	recovered, err := s.recover()
	if err != nil {
		return nil, err
	}
107

108 109 110 111 112
	if recovered {
		// Recovered. Now the state is already initialized,
		// and the master is ready.
		s.initDone = true
		close(s.ready)
113
		log.Info("Master recovered from saved state.")
114
	}
115

116
	return s, nil
117 118
}

119 120 121 122 123 124
// recover recovers service state from etcd.
func (s *Service) recover() (bool, error) {
	state, err := s.store.Load()
	if err != nil {
		return false, err
	}
125

126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
	if state == nil {
		log.Infoln("No state exists, not recovered.")
		return false, nil
	}

	log.Infof("Loaded snapshot of size: %d bytes.", len(state))
	gr, err := gzip.NewReader(bytes.NewReader(state))
	if err != nil {
		return false, err
	}

	dec := gob.NewDecoder(gr)
	var tqs taskQueues
	err = dec.Decode(&tqs)
	if err != nil {
		return false, err
	}

	err = gr.Close()
	if err != nil {
		// Only close failed, recover actually succeed, so
		// just log error.
		log.Errorln(err)
	}

	s.taskQueues = tqs
	return true, nil
153 154
}

155
// snapshot *must* be called with s.mu being held.
156
func (s *Service) snapshot() error {
157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
	// TOOD(helin): etcd request has a size limit, so the snapshot
	// size is limited by the max request size. We should either
	// divide the snapshot into smaller chunks and save under
	// different keys, or configure the request size to be big
	// enough:
	// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
	var buf bytes.Buffer
	gw := gzip.NewWriter(&buf)
	enc := gob.NewEncoder(gw)
	err := enc.Encode(s.taskQueues)
	if err != nil {
		return err
	}
	err = gw.Close()
	if err != nil {
		return err
	}

	state := buf.Bytes()
	log.Infof("Saving snapshot of size: %d bytes.", len(state))
	return s.store.Save(state)
178 179
}

H
Helin Wang 已提交
180
func readChunks(globPaths []string) ([]Chunk, error) {
181 182 183 184 185 186
	var chunks []Chunk
	var paths []string

	for _, s := range globPaths {
		match, err := filepath.Glob(s)
		if err != nil {
H
Helin Wang 已提交
187
			return nil, err
188 189 190 191 192
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
H
Helin Wang 已提交
193
		return nil, errors.New("no valid dataset specified")
194 195 196 197 198
	}

	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
H
Helin Wang 已提交
199
			return nil, err
200 201 202 203
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
H
Helin Wang 已提交
204
			return nil, err
205 206 207
		}
		err = f.Close()
		if err != nil {
H
Helin Wang 已提交
208
			return nil, err
209 210 211 212 213 214 215 216 217 218 219 220
		}

		count := index.NumChunks()
		for i := 0; i < count; i++ {
			chunk := Chunk{
				Path:  path,
				Index: *index.ChunkIndex(i),
			}
			chunks = append(chunks, chunk)
		}
	}

H
Helin Wang 已提交
221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
	return chunks, nil
}

// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
	if len(globPaths) == 0 {
		return errors.New("no dataset specified")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.initDone {
		// Already initialized. All trainer will call
		// SetDataset, but we only handle the first one. Treat
		// other calls as successful but do nothing.
		return nil
	}

H
Helin Wang 已提交
242
	chunks, err := readChunks(globPaths)
H
Helin Wang 已提交
243 244 245 246
	if err != nil {
		return err
	}

247 248
	s.taskQueues.Todo = partition(chunks, s.chunksPerTask)

H
Helin Wang 已提交
249
	err = s.snapshot()
250
	if err != nil {
H
Helin Wang 已提交
251
		log.Errorln(err)
252 253 254 255
		return err
	}

	close(s.ready)
H
Helin Wang 已提交
256
	s.initDone = true
257 258 259
	return nil
}

H
Helin Wang 已提交
260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
	return func() {
		s.mu.Lock()
		defer s.mu.Unlock()

		t, ok := s.taskQueues.Pending[taskID]
		if !ok {
			return
		}

		if t.Epoch != epoch {
			// new epoch, task launched after the
			// schedule of this timeout check.
			return
		}

		defer func() {
			err := s.snapshot()
			if err != nil {
				log.Errorln(err)
			}
		}()

		delete(s.taskQueues.Pending, t.Task.ID)

		t.NumTimeout++
		if t.NumTimeout > s.timeoutMax {
287
			log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
H
Helin Wang 已提交
288 289 290 291
			s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
			return
		}

292
		log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
H
Helin Wang 已提交
293 294 295 296
		s.taskQueues.Todo = append(s.taskQueues.Todo, t)
	}
}

H
Helin Wang 已提交
297 298 299 300 301 302 303 304 305 306
// logFields returns the current length of every task queue as log
// fields. It must be called with the lock held.
func (s *Service) logFields() log.Fields {
	q := &s.taskQueues

	fields := log.Fields{}
	fields["todoLen"] = len(q.Todo)
	fields["pendingLen"] = len(q.Pending)
	fields["doneLen"] = len(q.Done)
	fields["failedLen"] = len(q.Failed)
	return fields
}

307 308
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
309 310 311 312
	select {
	case <-s.ready:
	}

313 314 315 316
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.taskQueues.Todo) == 0 {
317 318
		if len(s.taskQueues.Done) == 0 {
			if len(s.taskQueues.Pending) == 0 {
H
Helin Wang 已提交
319
				err := errors.New("all task failed")
H
Helin Wang 已提交
320
				log.WithFields(s.logFields()).Warningln("All tasks failed.")
H
Helin Wang 已提交
321
				return err
322 323 324 325 326
			}

			// TODO(helin): client need to retry in this
			// error case. Gotcha: RPC client can't
			// compare returned error with predefined
H
Helin Wang 已提交
327 328 329 330 331
			// errors like io.EOF, because the error
			// instance deserialized from RPC is a
			// different instance than the error defined
			// in package. So we need to figure out a way
			// for client to check this error correctly.
H
Helin Wang 已提交
332
			err := errors.New("no more available task")
H
Helin Wang 已提交
333
			log.WithFields(s.logFields()).Warningln("No more available task.")
H
Helin Wang 已提交
334
			return err
335 336
		}
		s.taskQueues.Todo = s.taskQueues.Done
H
Helin Wang 已提交
337
		s.taskQueues.Done = nil
H
Helin Wang 已提交
338
		log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
339 340 341 342 343 344 345 346 347 348 349
	}

	t := s.taskQueues.Todo[0]
	t.Epoch++
	s.taskQueues.Todo = s.taskQueues.Todo[1:]
	s.taskQueues.Pending[t.Task.ID] = t
	err := s.snapshot()
	if err != nil {
		return err
	}

350
	*task = t.Task
H
Helin Wang 已提交
351
	log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID)
352

H
Helin Wang 已提交
353
	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
354 355 356 357 358
	return nil
}

// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
359 360 361 362
	select {
	case <-s.ready:
	}

363 364 365 366 367
	s.mu.Lock()
	defer s.mu.Unlock()

	t, ok := s.taskQueues.Pending[taskID]
	if !ok {
H
Helin Wang 已提交
368
		err := errors.New("pending task not found")
H
Helin Wang 已提交
369
		log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
H
Helin Wang 已提交
370
		return err
371 372 373 374 375 376
	}

	// task finished, reset timeout
	t.NumTimeout = 0
	s.taskQueues.Done = append(s.taskQueues.Done, t)
	delete(s.taskQueues.Pending, taskID)
377

H
Helin Wang 已提交
378 379
	log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)

H
Helin Wang 已提交
380
	if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
H
Helin Wang 已提交
381
		log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
382
		s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
383 384 385
		s.taskQueues.Done = nil
	}

H
Helin Wang 已提交
386 387 388 389 390
	err := s.snapshot()
	if err != nil {
		log.Errorln(err)
	}
	return err
391
}