service.go 11.5 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package master

import (
18 19 20
	"bytes"
	"compress/gzip"
	"encoding/gob"
21
	"errors"
22
	"math/rand"
23 24
	"os"
	"path/filepath"
25 26 27
	"sync"
	"time"

H
Helin Wang 已提交
28 29
	log "github.com/sirupsen/logrus"

30
	"github.com/PaddlePaddle/recordio"
31 32
)

33 34 35 36
const (
	// dialTimeout is presumably the timeout used when dialing the master
	// or its store from elsewhere in the package; it is not referenced in
	// this part of the file.
	dialTimeout = 5 * time.Second
)

// Errors returned by Service.GetTask. Their message texts are kept stable
// because clients may compare against them.
var (
	// ErrAllTaskFailed is returned when no task is waiting, pending or
	// done, i.e. every task of the current pass has failed. (The message
	// text says "finished" for historical reasons.)
	ErrAllTaskFailed = errors.New("all task finished")

	// ErrNoMoreAvailable is returned when no task is waiting to be
	// dispatched, but some tasks are still pending or already done.
	ErrNoMoreAvailable = errors.New("no more available task")

	// ErrPassBefore is returned when the client's pass number is smaller
	// than the master's pass counter.
	ErrPassBefore = errors.New("pass number smaller than master")

	// ErrPassAfter is returned when the client's pass number is larger
	// than the master's pass counter.
	ErrPassAfter = errors.New("pass number larger than master")
)

49 50 51 52
// Store saves and loads the master's serialized state, so that a restarted
// master can pick up where the previous one left off.
type Store interface {
	// Save persists a serialized snapshot of the master state.
	Save(state []byte) error
	// Load returns the last saved snapshot. A nil state (with a nil
	// error) is treated by the master as "no snapshot exists".
	Load() (state []byte, err error)
	// Shutdown shuts the store down; no caller is visible in this file.
	Shutdown() error
}

// Chunk is a chunk of data consisting of several data instances. It is
// identified by the RecordIO file it lives in and its index entry within
// that file; readChunks builds one Chunk per entry of each file's index.
type Chunk struct {
	Path  string         // path of the RecordIO file holding the chunk
	Index recordio.Index // chunk index
}

G
gongweibao 已提交
62 63 64 65 66 67
// TaskMeta identifies a task and the dispatch attempt it belongs to.
//
// ID is assigned once, when the dataset is partitioned into tasks. Epoch is
// incremented every time the task is dispatched, so timeout checks and
// failure reports that carry an older epoch can be recognized as stale.
type TaskMeta struct {
	ID, Epoch int
}

68 69
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
G
gongweibao 已提交
70
	Meta   TaskMeta
71 72 73 74
	Chunks []Chunk
}

type taskEntry struct {
G
gongweibao 已提交
75 76 77
	Task Task
	// A task fails if it's timeout or trainer reports it exits unnormally.
	NumFailure int
78 79
}

80
type masterState struct {
81 82 83 84 85
	Todo    []taskEntry
	Pending map[int]taskEntry // map from task ID to task entry
	Done    []taskEntry
	Failed  []taskEntry
	CurPass int
86 87
}

88 89
// Service is the master server service.
type Service struct {
G
gongweibao 已提交
90 91 92 93
	chunksPerTask int
	timeoutDur    time.Duration
	failureMax    int
	store         Store
94

95 96 97
	ready    chan struct{}
	initDone bool

98 99 100 101 102
	mu sync.Mutex
	// State to be persisted to snapshot.
	state masterState
	// The trainer that is currently saving model. This state is
	// transient, does not need to be persisted to snapshot.
103
	savingTrainer string
104 105
}

H
Helin Wang 已提交
106
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
107 108 109 110 111 112
	// generate uniq id across job using nanosecond + randint + counter
	// FIXME(typhoonzero): this is a workaround, use uuid
	randStart := rand.Int()
	counter := 0
	timestamp := time.Now().Nanosecond()
	id := timestamp + randStart + counter
H
Helin Wang 已提交
113 114
	if chunksPerTask <= 0 {
		chunksPerTask = 1
115 116 117 118 119
	}

	var result []taskEntry
	var cur taskEntry
	for i, c := range chunks {
H
Helin Wang 已提交
120
		if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
G
gongweibao 已提交
121
			cur.Task.Meta.ID = id
122 123
			counter++
			id = timestamp + randStart + counter
124 125 126 127 128 129 130 131
			result = append(result, cur)
			cur.Task.Chunks = nil
		}

		cur.Task.Chunks = append(cur.Task.Chunks, c)
	}

	if len(cur.Task.Chunks) > 0 {
G
gongweibao 已提交
132
		cur.Task.Meta.ID = id
133 134 135 136 137 138 139
		result = append(result, cur)
	}

	return result
}

// NewService creates a new service.
G
gongweibao 已提交
140
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
141
	s := &Service{}
142
	s.chunksPerTask = chunksPerTask
143
	s.timeoutDur = timeoutDur
G
gongweibao 已提交
144
	s.failureMax = failureMax
145 146
	s.state = masterState{}
	s.state.Pending = make(map[int]taskEntry)
147
	s.ready = make(chan struct{})
148 149 150 151 152
	s.store = store
	recovered, err := s.recover()
	if err != nil {
		return nil, err
	}
153

154 155 156 157 158
	if recovered {
		// Recovered. Now the state is already initialized,
		// and the master is ready.
		s.initDone = true
		close(s.ready)
159
		log.Info("Master recovered from saved state.")
160
	}
161

162
	return s, nil
163 164
}

165 166 167 168 169 170
// recover recovers service state from etcd.
func (s *Service) recover() (bool, error) {
	state, err := s.store.Load()
	if err != nil {
		return false, err
	}
171

172 173 174 175 176 177 178 179 180 181 182 183
	if state == nil {
		log.Infoln("No state exists, not recovered.")
		return false, nil
	}

	log.Infof("Loaded snapshot of size: %d bytes.", len(state))
	gr, err := gzip.NewReader(bytes.NewReader(state))
	if err != nil {
		return false, err
	}

	dec := gob.NewDecoder(gr)
184
	var tqs masterState
185 186 187 188 189 190 191 192 193 194 195 196
	err = dec.Decode(&tqs)
	if err != nil {
		return false, err
	}

	err = gr.Close()
	if err != nil {
		// Only close failed, recover actually succeed, so
		// just log error.
		log.Errorln(err)
	}

197 198 199 200 201 202
	s.state = tqs
	log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.")
	for _, t := range s.state.Pending {
		time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
	}

203
	return true, nil
204 205
}

206
// snapshot *must* be called with s.mu being held.
207
func (s *Service) snapshot() error {
208
	// TODO(helin): etcd request has a size limit, so the snapshot
209 210 211 212 213 214 215 216
	// size is limited by the max request size. We should either
	// divide the snapshot into smaller chunks and save under
	// different keys, or configure the request size to be big
	// enough:
	// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
	var buf bytes.Buffer
	gw := gzip.NewWriter(&buf)
	enc := gob.NewEncoder(gw)
217
	err := enc.Encode(s.state)
218 219 220 221 222 223 224 225 226 227 228
	if err != nil {
		return err
	}
	err = gw.Close()
	if err != nil {
		return err
	}

	state := buf.Bytes()
	log.Infof("Saving snapshot of size: %d bytes.", len(state))
	return s.store.Save(state)
229 230
}

H
Helin Wang 已提交
231
func readChunks(globPaths []string) ([]Chunk, error) {
232 233 234 235 236 237
	var chunks []Chunk
	var paths []string

	for _, s := range globPaths {
		match, err := filepath.Glob(s)
		if err != nil {
H
Helin Wang 已提交
238
			return nil, err
239 240 241 242 243
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
H
Helin Wang 已提交
244
		return nil, errors.New("no valid dataset specified")
245 246 247 248 249
	}

	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
H
Helin Wang 已提交
250
			return nil, err
251 252 253 254
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
H
Helin Wang 已提交
255
			return nil, err
256 257 258
		}
		err = f.Close()
		if err != nil {
H
Helin Wang 已提交
259
			return nil, err
260 261 262
		}

		count := index.NumChunks()
263
		log.Infof("readChunks: file %s has %d chunks", path, count)
264 265 266 267 268 269 270 271 272
		for i := 0; i < count; i++ {
			chunk := Chunk{
				Path:  path,
				Index: *index.ChunkIndex(i),
			}
			chunks = append(chunks, chunk)
		}
	}

H
Helin Wang 已提交
273 274 275 276 277 278 279
	return chunks, nil
}

// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
280
func (s *Service) SetDataset(globPaths []string, _ *int) error {
H
Helin Wang 已提交
281 282 283 284 285 286 287 288 289 290 291 292 293
	if len(globPaths) == 0 {
		return errors.New("no dataset specified")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.initDone {
		// Already initialized. All trainer will call
		// SetDataset, but we only handle the first one. Treat
		// other calls as successful but do nothing.
		return nil
	}

H
Helin Wang 已提交
294
	chunks, err := readChunks(globPaths)
H
Helin Wang 已提交
295 296 297 298
	if err != nil {
		return err
	}

299
	s.state.Todo = partition(chunks, s.chunksPerTask)
300

H
Helin Wang 已提交
301
	err = s.snapshot()
302
	if err != nil {
H
Helin Wang 已提交
303
		log.Errorln(err)
304 305 306
		return err
	}
	close(s.ready)
H
Helin Wang 已提交
307
	s.initDone = true
308 309 310
	return nil
}

311 312
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
G
gongweibao 已提交
313 314
func (s *Service) processFailedTask(t taskEntry, epoch int) {
	if t.Task.Meta.Epoch != epoch {
G
gongweibao 已提交
315 316 317 318 319 320 321 322 323 324 325 326
		// new epoch, task launched after the
		// schedule of this timeout check or failed status report.
		return
	}

	defer func() {
		err := s.snapshot()
		if err != nil {
			log.Errorln(err)
		}
	}()

327
	delete(s.state.Pending, t.Task.Meta.ID)
G
gongweibao 已提交
328

G
gongweibao 已提交
329 330 331
	t.NumFailure++
	if t.NumFailure > s.failureMax {
		log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
332
		s.state.Failed = append(s.state.Failed, t)
G
gongweibao 已提交
333 334 335
		return
	}

336
	log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
337
	s.state.Todo = append(s.state.Todo, t)
338
	return
G
gongweibao 已提交
339 340
}

H
Helin Wang 已提交
341 342 343 344 345
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
	return func() {
		s.mu.Lock()
		defer s.mu.Unlock()

346
		t, ok := s.state.Pending[taskID]
H
Helin Wang 已提交
347 348 349 350
		if !ok {
			return
		}

G
gongweibao 已提交
351
		s.processFailedTask(t, epoch)
H
Helin Wang 已提交
352 353 354
	}
}

H
Helin Wang 已提交
355 356 357
// must be called with lock held.
func (s *Service) logFields() log.Fields {
	return log.Fields{
358 359 360 361
		"todoLen":    len(s.state.Todo),
		"pendingLen": len(s.state.Pending),
		"doneLen":    len(s.state.Done),
		"failedLen":  len(s.state.Failed),
362
		"curPass":    s.state.CurPass,
H
Helin Wang 已提交
363 364 365
	}
}

366
// GetTask gets a new task from the service.
367 368
// passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
369 370 371 372
	select {
	case <-s.ready:
	}

373 374
	s.mu.Lock()
	defer s.mu.Unlock()
375
	if passID < s.state.CurPass {
376 377
		return ErrPassBefore
	}
378
	if passID > s.state.CurPass {
379 380 381 382
		// Client may get run to pass after master when one client faster than the
		// other
		return ErrPassAfter
	}
383

384 385
	if len(s.state.Todo) == 0 {
		if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
386 387
			log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
			return ErrAllTaskFailed
388
		}
389 390
		log.WithFields(s.logFields()).Warningln("No more available task.")
		return ErrNoMoreAvailable
391 392
	}

393
	t := s.state.Todo[0]
G
gongweibao 已提交
394
	t.Task.Meta.Epoch++
395 396
	s.state.Todo = s.state.Todo[1:]
	s.state.Pending[t.Task.Meta.ID] = t
397 398 399 400 401
	err := s.snapshot()
	if err != nil {
		return err
	}

402
	*task = t.Task
G
gongweibao 已提交
403
	log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)
404

G
gongweibao 已提交
405
	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
406 407 408 409
	return nil
}

// TaskFinished tell the service that a task is finished.
410
func (s *Service) TaskFinished(taskID int, dummy *int) error {
411 412 413 414
	select {
	case <-s.ready:
	}

415 416 417
	s.mu.Lock()
	defer s.mu.Unlock()

418
	t, ok := s.state.Pending[taskID]
419
	if !ok {
H
Helin Wang 已提交
420
		log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
G
gongweibao 已提交
421
		return nil
422 423 424
	}

	// task finished, reset timeout
G
gongweibao 已提交
425
	t.NumFailure = 0
426 427
	s.state.Done = append(s.state.Done, t)
	delete(s.state.Pending, taskID)
428

H
Helin Wang 已提交
429
	log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
430
	if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
431
		// increase master side pass count if all tasks finished
432
		s.state.CurPass++
433
		s.state.Todo = append(s.state.Done, s.state.Failed...)
434
		s.state.Done = []taskEntry{}
435
		// TODO(typhoonzero): deal with failed tasks
436 437
		s.state.Failed = []taskEntry{}
		log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass)
438 439
	}

H
Helin Wang 已提交
440 441 442 443 444
	err := s.snapshot()
	if err != nil {
		log.Errorln(err)
	}
	return err
445
}
G
gongweibao 已提交
446

G
gongweibao 已提交
447
// TaskFailed tells the service that a task is failed.
448
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
G
gongweibao 已提交
449 450 451 452 453 454 455
	select {
	case <-s.ready:
	}

	s.mu.Lock()
	defer s.mu.Unlock()

456
	t, ok := s.state.Pending[meta.ID]
G
gongweibao 已提交
457
	if !ok {
G
gongweibao 已提交
458
		log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
G
gongweibao 已提交
459
		return nil
G
gongweibao 已提交
460 461
	}

G
gongweibao 已提交
462
	s.processFailedTask(t, meta.Epoch)
G
gongweibao 已提交
463 464
	return nil
}
465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503

// SaveModelRequest is the request for saving model.
//
// TrainerID identifies the requesting trainer and must not be empty.
// BlockDur is how long an approval lasts: RequestSaveModel releases the
// approval BlockDur after granting it, after which another trainer may be
// approved.
type SaveModelRequest struct {
	TrainerID string
	BlockDur  time.Duration
}

// RequestSaveModel requests the master server to approve the caller
// to save the model.
//
// At most one trainer holds the approval at a time. *need is set to true
// when no trainer holds it or when the caller already holds it; otherwise it
// is set to false. Each grant arms a timer that releases the approval after
// req.BlockDur, but only if it still belongs to the same trainer, so a stale
// timer cannot revoke another trainer's approval. An empty req.TrainerID is
// rejected with an error.
//
// NOTE: when the holder asks again, the timer from its earlier grant still
// fires first and ends the approval at the earlier deadline.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if req.TrainerID == "" {
		return errors.New("trainer id is empty")
	}

	// A trainer that is already saving may ask again and stays approved.
	*need = s.savingTrainer == "" || s.savingTrainer == req.TrainerID
	if !*need {
		return nil
	}

	s.savingTrainer = req.TrainerID
	time.AfterFunc(req.BlockDur, func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		// The slot may have expired and been granted to another trainer
		// since this timer was armed; only release our own grant.
		if s.savingTrainer == req.TrainerID {
			s.savingTrainer = ""
		}
	})
	return nil
}