service.go 8.4 KB
Newer Older
1 2 3
package master

import (
4 5 6
	"bytes"
	"compress/gzip"
	"encoding/gob"
7
	"errors"
8 9
	"os"
	"path/filepath"
10 11 12
	"sync"
	"time"

H
Helin Wang 已提交
13 14
	log "github.com/sirupsen/logrus"

15
	"github.com/PaddlePaddle/recordio"
16 17
)

18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
const (
	// dialTimeout is a five-second timeout. It is not referenced in the
	// visible part of this file; presumably it bounds connection dialing
	// elsewhere in the package (TODO confirm against its users).
	dialTimeout = 5 * time.Second
)

// Store is the interface for saving and loading the master state.
type Store interface {
	// Save is called by Service.snapshot with the gzip-compressed,
	// gob-encoded task queues.
	Save([]byte) error
	// Load returns the previously saved state. Service.recover treats a
	// nil slice with a nil error as "no state has been saved yet".
	Load() ([]byte, error)
}

// Chunk is a chunk of data consisting of several data instances.
type Chunk struct {
	Path  string         // path of the file the chunk index was loaded from
	Index recordio.Index // chunk index
}

// Task is the basic unit of data instances assigned to trainers.
type Task struct {
	ID     int     // task identifier; partition numbers tasks sequentially from 0
	Chunks []Chunk // chunks that make up this task
}

// taskEntry is a Task together with its scheduling bookkeeping.
type taskEntry struct {
	Epoch      int  // incremented each time GetTask dispatches the task
	NumTimeout int  // incremented on each timeout; reset to 0 by TaskFinished
	Task       Task // the task being tracked
}

// taskQueues holds every task grouped by scheduling state. It is the value
// that Service.snapshot gob-encodes and Service.recover decodes.
type taskQueues struct {
	Todo    []taskEntry       // tasks waiting to be dispatched
	Pending map[int]taskEntry // map from task ID to task entry
	Done    []taskEntry       // finished tasks; moved back to Todo when a new pass starts
	Failed  []Task            // tasks discarded after timing out more than timeoutMax times
}

53 54
// Service is the master server service.
type Service struct {
55 56 57 58
	chunksPerTask int
	timeoutDur    time.Duration
	timeoutMax    int
	ready         chan struct{}
59
	store         Store
60 61

	mu         sync.Mutex
H
Helin Wang 已提交
62
	initDone   bool
63 64 65
	taskQueues taskQueues
}

H
Helin Wang 已提交
66
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
67
	id := 0
H
Helin Wang 已提交
68 69
	if chunksPerTask <= 0 {
		chunksPerTask = 1
70 71 72 73 74
	}

	var result []taskEntry
	var cur taskEntry
	for i, c := range chunks {
H
Helin Wang 已提交
75
		if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
			cur.Task.ID = id
			id++
			result = append(result, cur)
			cur.Task.Chunks = nil
		}

		cur.Task.Chunks = append(cur.Task.Chunks, c)
	}

	if len(cur.Task.Chunks) > 0 {
		cur.Task.ID = id
		result = append(result, cur)
	}

	return result
}

// NewService creates a new service.
94
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) (*Service, error) {
95
	s := &Service{}
96
	s.chunksPerTask = chunksPerTask
97 98 99 100
	s.timeoutDur = timeoutDur
	s.timeoutMax = timeoutMax
	s.taskQueues = taskQueues{}
	s.taskQueues.Pending = make(map[int]taskEntry)
101
	s.ready = make(chan struct{})
102 103 104 105 106
	s.store = store
	recovered, err := s.recover()
	if err != nil {
		return nil, err
	}
107

108 109 110 111 112 113
	if recovered {
		// Recovered. Now the state is already initialized,
		// and the master is ready.
		s.initDone = true
		close(s.ready)
	}
114

115
	return s, nil
116 117
}

118 119 120 121 122 123
// recover recovers service state from etcd.
func (s *Service) recover() (bool, error) {
	state, err := s.store.Load()
	if err != nil {
		return false, err
	}
124

125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
	if state == nil {
		log.Infoln("No state exists, not recovered.")
		return false, nil
	}

	log.Infof("Loaded snapshot of size: %d bytes.", len(state))
	gr, err := gzip.NewReader(bytes.NewReader(state))
	if err != nil {
		return false, err
	}

	dec := gob.NewDecoder(gr)
	var tqs taskQueues
	err = dec.Decode(&tqs)
	if err != nil {
		return false, err
	}

	err = gr.Close()
	if err != nil {
		// Only close failed, recover actually succeed, so
		// just log error.
		log.Errorln(err)
	}

	s.taskQueues = tqs
	return true, nil
152 153
}

154
// snapshot *must* be called with s.mu being held.
155
func (s *Service) snapshot() error {
156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176
	// TOOD(helin): etcd request has a size limit, so the snapshot
	// size is limited by the max request size. We should either
	// divide the snapshot into smaller chunks and save under
	// different keys, or configure the request size to be big
	// enough:
	// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
	var buf bytes.Buffer
	gw := gzip.NewWriter(&buf)
	enc := gob.NewEncoder(gw)
	err := enc.Encode(s.taskQueues)
	if err != nil {
		return err
	}
	err = gw.Close()
	if err != nil {
		return err
	}

	state := buf.Bytes()
	log.Infof("Saving snapshot of size: %d bytes.", len(state))
	return s.store.Save(state)
177 178
}

H
Helin Wang 已提交
179
func readChunks(globPaths []string) ([]Chunk, error) {
180 181 182 183 184 185
	var chunks []Chunk
	var paths []string

	for _, s := range globPaths {
		match, err := filepath.Glob(s)
		if err != nil {
H
Helin Wang 已提交
186
			return nil, err
187 188 189 190 191
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
H
Helin Wang 已提交
192
		return nil, errors.New("no valid dataset specified")
193 194 195 196 197
	}

	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
H
Helin Wang 已提交
198
			return nil, err
199 200 201 202
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
H
Helin Wang 已提交
203
			return nil, err
204 205 206
		}
		err = f.Close()
		if err != nil {
H
Helin Wang 已提交
207
			return nil, err
208 209 210 211 212 213 214 215 216 217 218 219
		}

		count := index.NumChunks()
		for i := 0; i < count; i++ {
			chunk := Chunk{
				Path:  path,
				Index: *index.ChunkIndex(i),
			}
			chunks = append(chunks, chunk)
		}
	}

H
Helin Wang 已提交
220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240
	return chunks, nil
}

// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func (s *Service) SetDataset(globPaths []string, dummy *int) error {
	if len(globPaths) == 0 {
		return errors.New("no dataset specified")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.initDone {
		// Already initialized. All trainer will call
		// SetDataset, but we only handle the first one. Treat
		// other calls as successful but do nothing.
		return nil
	}

H
Helin Wang 已提交
241
	chunks, err := readChunks(globPaths)
H
Helin Wang 已提交
242 243 244 245
	if err != nil {
		return err
	}

246 247
	s.taskQueues.Todo = partition(chunks, s.chunksPerTask)

H
Helin Wang 已提交
248
	err = s.snapshot()
249
	if err != nil {
H
Helin Wang 已提交
250
		log.Errorln(err)
251 252 253 254
		return err
	}

	close(s.ready)
H
Helin Wang 已提交
255
	s.initDone = true
256 257 258
	return nil
}

H
Helin Wang 已提交
259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
	return func() {
		s.mu.Lock()
		defer s.mu.Unlock()

		t, ok := s.taskQueues.Pending[taskID]
		if !ok {
			return
		}

		if t.Epoch != epoch {
			// new epoch, task launched after the
			// schedule of this timeout check.
			return
		}

		defer func() {
			err := s.snapshot()
			if err != nil {
				log.Errorln(err)
			}
		}()

		delete(s.taskQueues.Pending, t.Task.ID)

		t.NumTimeout++
		if t.NumTimeout > s.timeoutMax {
286
			log.Warningf("Task %v timed out %d times, discard.", t.Task, t.NumTimeout)
H
Helin Wang 已提交
287 288 289 290
			s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
			return
		}

291
		log.Warningf("Task %v timed out %d times, retry.", t.Task, t.NumTimeout)
H
Helin Wang 已提交
292 293 294 295
		s.taskQueues.Todo = append(s.taskQueues.Todo, t)
	}
}

H
Helin Wang 已提交
296 297 298 299 300 301 302 303 304 305
// logFields reports the current length of each task queue as structured log
// fields. It must be called with the lock held.
func (s *Service) logFields() log.Fields {
	queues := &s.taskQueues

	fields := make(log.Fields, 4)
	fields["todoLen"] = len(queues.Todo)
	fields["pendingLen"] = len(queues.Pending)
	fields["doneLen"] = len(queues.Done)
	fields["failedLen"] = len(queues.Failed)
	return fields
}

306 307
// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
308 309 310 311
	select {
	case <-s.ready:
	}

312 313 314 315
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.taskQueues.Todo) == 0 {
316 317
		if len(s.taskQueues.Done) == 0 {
			if len(s.taskQueues.Pending) == 0 {
H
Helin Wang 已提交
318
				err := errors.New("all task failed")
H
Helin Wang 已提交
319
				log.WithFields(s.logFields()).Warningln("All tasks failed.")
H
Helin Wang 已提交
320
				return err
321 322 323 324 325
			}

			// TODO(helin): client need to retry in this
			// error case. Gotcha: RPC client can't
			// compare returned error with predefined
H
Helin Wang 已提交
326 327 328 329 330
			// errors like io.EOF, because the error
			// instance deserialized from RPC is a
			// different instance than the error defined
			// in package. So we need to figure out a way
			// for client to check this error correctly.
H
Helin Wang 已提交
331
			err := errors.New("no more available task")
H
Helin Wang 已提交
332
			log.WithFields(s.logFields()).Warningln("No more available task.")
H
Helin Wang 已提交
333
			return err
334 335
		}
		s.taskQueues.Todo = s.taskQueues.Done
H
Helin Wang 已提交
336
		s.taskQueues.Done = nil
H
Helin Wang 已提交
337
		log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
338 339 340 341 342 343 344 345 346 347 348
	}

	t := s.taskQueues.Todo[0]
	t.Epoch++
	s.taskQueues.Todo = s.taskQueues.Todo[1:]
	s.taskQueues.Pending[t.Task.ID] = t
	err := s.snapshot()
	if err != nil {
		return err
	}

349
	*task = t.Task
H
Helin Wang 已提交
350
	log.WithFields(s.logFields()).Infof("Task #%d dispatched.", task.ID)
351

H
Helin Wang 已提交
352
	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
353 354 355 356 357
	return nil
}

// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
358 359 360 361
	select {
	case <-s.ready:
	}

362 363 364 365 366
	s.mu.Lock()
	defer s.mu.Unlock()

	t, ok := s.taskQueues.Pending[taskID]
	if !ok {
H
Helin Wang 已提交
367
		err := errors.New("pending task not found")
H
Helin Wang 已提交
368
		log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
H
Helin Wang 已提交
369
		return err
370 371 372 373 374 375
	}

	// task finished, reset timeout
	t.NumTimeout = 0
	s.taskQueues.Done = append(s.taskQueues.Done, t)
	delete(s.taskQueues.Pending, taskID)
376

H
Helin Wang 已提交
377 378
	log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)

H
Helin Wang 已提交
379
	if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
H
Helin Wang 已提交
380
		log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
381
		s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
382 383 384
		s.taskQueues.Done = nil
	}

H
Helin Wang 已提交
385 386 387 388 389
	err := s.snapshot()
	if err != nil {
		log.Errorln(err)
	}
	return err
390
}