service.go 11.5 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package master

import (
18 19 20
	"bytes"
	"compress/gzip"
	"encoding/gob"
21
	"errors"
22
	"math/rand"
23 24
	"os"
	"path/filepath"
25 26 27
	"sync"
	"time"

H
Helin Wang 已提交
28 29
	log "github.com/sirupsen/logrus"

30
	"github.com/PaddlePaddle/recordio"
31 32
)

33 34 35 36
// dialTimeout is a five second timeout.
// NOTE(review): it is not referenced in the visible part of this file;
// presumably used when dialing from elsewhere in the package — confirm.
const dialTimeout = 5 * time.Second

37 38 39 40 41 42 43 44 45 46 47 48
// Sentinel errors returned by the master service.
var (
	// ErrAllTaskFailed is returned by GetTask when there is nothing left
	// to dispatch and no task is pending or done.
	ErrAllTaskFailed = errors.New("all task finished")

	// ErrNoMoreAvailable is returned by GetTask when the todo queue is
	// empty but some tasks are still pending or already done.
	ErrNoMoreAvailable = errors.New("no more available task")

	// ErrPassBefore is returned when the client's pass number is smaller
	// than the master's pass counter.
	ErrPassBefore = errors.New("pass number smaller than master")

	// ErrPassAfter is returned when the client's pass number is larger
	// than the master's pass counter.
	ErrPassAfter = errors.New("pass number larger than master")
)

49 50 51 52
// Store is the interface for saving and loading the master state.
type Store interface {
	// Save persists the serialized master state.
	Save([]byte) error
	// Load returns the previously saved state. The service treats a nil
	// slice together with a nil error as "no state has been saved yet".
	Load() ([]byte, error)
	// Shutdown shuts the store down.
	// NOTE(review): not called in the visible part of this file.
	Shutdown() error
}

// Chunk is a chunk of data consisting of several data instances.
type Chunk struct {
	// Path is the path of the file the chunk was read from.
	Path  string
	Index recordio.Index // chunk index
}

G
gongweibao 已提交
62 63 64 65 66 67
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
	// ID identifies the task; it is assigned when the dataset is
	// partitioned into tasks.
	ID    int
	// Epoch is incremented every time the task is dispatched by GetTask.
	// Failure reports and timeout checks carrying a different epoch are
	// ignored.
	Epoch int
}

68 69
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
G
gongweibao 已提交
70
	Meta   TaskMeta
71 72 73 74
	Chunks []Chunk
}

type taskEntry struct {
G
gongweibao 已提交
75 76 77
	Task Task
	// A task fails if it's timeout or trainer reports it exits unnormally.
	NumFailure int
78 79
}

80 81 82 83 84 85 86
type masterState struct {
	Todo     []taskEntry
	Pending  map[int]taskEntry // map from task ID to task entry
	Done     []taskEntry
	Failed   []taskEntry
	CurPass  int
	JobTasks []taskEntry
87 88
}

89 90
// Service is the master server service.
type Service struct {
G
gongweibao 已提交
91 92 93 94
	chunksPerTask int
	timeoutDur    time.Duration
	failureMax    int
	store         Store
95

96 97 98
	ready    chan struct{}
	initDone bool

99 100 101 102 103
	mu sync.Mutex
	// State to be persisted to snapshot.
	state masterState
	// The trainer that is currently saving model. This state is
	// transient, does not need to be persisted to snapshot.
104
	savingTrainer string
105 106
}

H
Helin Wang 已提交
107
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
108 109 110 111 112 113
	// generate uniq id across job using nanosecond + randint + counter
	// FIXME(typhoonzero): this is a workaround, use uuid
	randStart := rand.Int()
	counter := 0
	timestamp := time.Now().Nanosecond()
	id := timestamp + randStart + counter
H
Helin Wang 已提交
114 115
	if chunksPerTask <= 0 {
		chunksPerTask = 1
116 117 118 119 120
	}

	var result []taskEntry
	var cur taskEntry
	for i, c := range chunks {
H
Helin Wang 已提交
121
		if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
G
gongweibao 已提交
122
			cur.Task.Meta.ID = id
123 124
			counter++
			id = timestamp + randStart + counter
125 126 127 128 129 130 131 132
			result = append(result, cur)
			cur.Task.Chunks = nil
		}

		cur.Task.Chunks = append(cur.Task.Chunks, c)
	}

	if len(cur.Task.Chunks) > 0 {
G
gongweibao 已提交
133
		cur.Task.Meta.ID = id
134 135 136 137 138 139 140
		result = append(result, cur)
	}

	return result
}

// NewService creates a new service.
G
gongweibao 已提交
141
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
142
	s := &Service{}
143
	s.chunksPerTask = chunksPerTask
144
	s.timeoutDur = timeoutDur
G
gongweibao 已提交
145
	s.failureMax = failureMax
146 147
	s.state = masterState{}
	s.state.Pending = make(map[int]taskEntry)
148
	s.ready = make(chan struct{})
149 150 151 152 153
	s.store = store
	recovered, err := s.recover()
	if err != nil {
		return nil, err
	}
154

155 156 157 158 159
	if recovered {
		// Recovered. Now the state is already initialized,
		// and the master is ready.
		s.initDone = true
		close(s.ready)
160
		log.Info("Master recovered from saved state.")
161
	}
162

163
	return s, nil
164 165
}

166 167 168 169 170 171
// recover recovers service state from etcd.
func (s *Service) recover() (bool, error) {
	state, err := s.store.Load()
	if err != nil {
		return false, err
	}
172

173 174 175 176 177 178 179 180 181 182 183 184
	if state == nil {
		log.Infoln("No state exists, not recovered.")
		return false, nil
	}

	log.Infof("Loaded snapshot of size: %d bytes.", len(state))
	gr, err := gzip.NewReader(bytes.NewReader(state))
	if err != nil {
		return false, err
	}

	dec := gob.NewDecoder(gr)
185
	var tqs masterState
186 187 188 189 190 191 192 193 194 195 196 197
	err = dec.Decode(&tqs)
	if err != nil {
		return false, err
	}

	err = gr.Close()
	if err != nil {
		// Only close failed, recover actually succeed, so
		// just log error.
		log.Errorln(err)
	}

198 199 200 201 202 203
	s.state = tqs
	log.WithFields(s.logFields()).Infof("Master recovered from snapshot, scheduling pending task timeout check.")
	for _, t := range s.state.Pending {
		time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
	}

204
	return true, nil
205 206
}

207
// snapshot *must* be called with s.mu being held.
208
func (s *Service) snapshot() error {
209
	// TODO(helin): etcd request has a size limit, so the snapshot
210 211 212 213 214 215 216 217
	// size is limited by the max request size. We should either
	// divide the snapshot into smaller chunks and save under
	// different keys, or configure the request size to be big
	// enough:
	// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
	var buf bytes.Buffer
	gw := gzip.NewWriter(&buf)
	enc := gob.NewEncoder(gw)
218
	err := enc.Encode(s.state)
219 220 221 222 223 224 225 226 227 228 229
	if err != nil {
		return err
	}
	err = gw.Close()
	if err != nil {
		return err
	}

	state := buf.Bytes()
	log.Infof("Saving snapshot of size: %d bytes.", len(state))
	return s.store.Save(state)
230 231
}

H
Helin Wang 已提交
232
func readChunks(globPaths []string) ([]Chunk, error) {
233 234 235 236 237 238
	var chunks []Chunk
	var paths []string

	for _, s := range globPaths {
		match, err := filepath.Glob(s)
		if err != nil {
H
Helin Wang 已提交
239
			return nil, err
240 241 242 243 244
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
H
Helin Wang 已提交
245
		return nil, errors.New("no valid dataset specified")
246 247 248 249 250
	}

	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
H
Helin Wang 已提交
251
			return nil, err
252 253 254 255
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
H
Helin Wang 已提交
256
			return nil, err
257 258 259
		}
		err = f.Close()
		if err != nil {
H
Helin Wang 已提交
260
			return nil, err
261 262 263
		}

		count := index.NumChunks()
264
		log.Infof("readChunks: file %s has %d chunks", path, count)
265 266 267 268 269 270 271 272 273
		for i := 0; i < count; i++ {
			chunk := Chunk{
				Path:  path,
				Index: *index.ChunkIndex(i),
			}
			chunks = append(chunks, chunk)
		}
	}

H
Helin Wang 已提交
274 275 276 277 278 279 280
	return chunks, nil
}

// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
281
func (s *Service) SetDataset(globPaths []string, _ *int) error {
H
Helin Wang 已提交
282 283 284 285 286 287 288 289 290 291 292 293 294
	if len(globPaths) == 0 {
		return errors.New("no dataset specified")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.initDone {
		// Already initialized. All trainer will call
		// SetDataset, but we only handle the first one. Treat
		// other calls as successful but do nothing.
		return nil
	}

H
Helin Wang 已提交
295
	chunks, err := readChunks(globPaths)
H
Helin Wang 已提交
296 297 298 299
	if err != nil {
		return err
	}

300 301
	s.state.JobTasks = partition(chunks, s.chunksPerTask)
	s.state.Todo = s.state.JobTasks
302

H
Helin Wang 已提交
303
	err = s.snapshot()
304
	if err != nil {
H
Helin Wang 已提交
305
		log.Errorln(err)
306 307 308
		return err
	}
	close(s.ready)
H
Helin Wang 已提交
309
	s.initDone = true
310 311 312
	return nil
}

313 314
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
G
gongweibao 已提交
315 316
func (s *Service) processFailedTask(t taskEntry, epoch int) {
	if t.Task.Meta.Epoch != epoch {
G
gongweibao 已提交
317 318 319 320 321 322 323 324 325 326 327 328
		// new epoch, task launched after the
		// schedule of this timeout check or failed status report.
		return
	}

	defer func() {
		err := s.snapshot()
		if err != nil {
			log.Errorln(err)
		}
	}()

329
	delete(s.state.Pending, t.Task.Meta.ID)
G
gongweibao 已提交
330

G
gongweibao 已提交
331 332 333
	t.NumFailure++
	if t.NumFailure > s.failureMax {
		log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
334
		s.state.Failed = append(s.state.Failed, t)
G
gongweibao 已提交
335 336 337
		return
	}

338
	log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
339
	s.state.Todo = append(s.state.Todo, t)
340
	return
G
gongweibao 已提交
341 342
}

H
Helin Wang 已提交
343 344 345 346 347
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
	return func() {
		s.mu.Lock()
		defer s.mu.Unlock()

348
		t, ok := s.state.Pending[taskID]
H
Helin Wang 已提交
349 350 351 352
		if !ok {
			return
		}

G
gongweibao 已提交
353
		s.processFailedTask(t, epoch)
H
Helin Wang 已提交
354 355 356
	}
}

H
Helin Wang 已提交
357 358 359
// must be called with lock held.
func (s *Service) logFields() log.Fields {
	return log.Fields{
360 361 362 363
		"todoLen":    len(s.state.Todo),
		"pendingLen": len(s.state.Pending),
		"doneLen":    len(s.state.Done),
		"failedLen":  len(s.state.Failed),
H
Helin Wang 已提交
364 365 366
	}
}

367
// GetTask gets a new task from the service.
368 369
// passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
370 371 372 373
	select {
	case <-s.ready:
	}

374 375
	s.mu.Lock()
	defer s.mu.Unlock()
376
	if passID < s.state.CurPass {
377 378
		return ErrPassBefore
	}
379
	if passID > s.state.CurPass {
380 381 382 383
		// Client may get run to pass after master when one client faster than the
		// other
		return ErrPassAfter
	}
384

385 386
	if len(s.state.Todo) == 0 {
		if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
387 388
			log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
			return ErrAllTaskFailed
389
		}
390 391
		log.WithFields(s.logFields()).Warningln("No more available task.")
		return ErrNoMoreAvailable
392 393
	}

394
	t := s.state.Todo[0]
G
gongweibao 已提交
395
	t.Task.Meta.Epoch++
396 397
	s.state.Todo = s.state.Todo[1:]
	s.state.Pending[t.Task.Meta.ID] = t
398 399 400 401 402
	err := s.snapshot()
	if err != nil {
		return err
	}

403
	*task = t.Task
G
gongweibao 已提交
404
	log.WithFields(s.logFields()).Infof("Task #%v dispatched.", t.Task.Meta)
405

G
gongweibao 已提交
406
	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
407 408 409 410
	return nil
}

// TaskFinished tell the service that a task is finished.
411
func (s *Service) TaskFinished(taskID int, dummy *int) error {
412 413 414 415
	select {
	case <-s.ready:
	}

416 417 418
	s.mu.Lock()
	defer s.mu.Unlock()

419
	t, ok := s.state.Pending[taskID]
420
	if !ok {
H
Helin Wang 已提交
421
		log.WithFields(s.logFields()).Warningln("Pending task #%d not found.", taskID)
G
gongweibao 已提交
422
		return nil
423 424 425
	}

	// task finished, reset timeout
G
gongweibao 已提交
426
	t.NumFailure = 0
427 428
	s.state.Done = append(s.state.Done, t)
	delete(s.state.Pending, taskID)
429

H
Helin Wang 已提交
430
	log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
431
	if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
432
		// increase master side pass count if all tasks finished
433 434 435
		s.state.CurPass++
		s.state.Todo = s.state.JobTasks
		s.state.Done = []taskEntry{}
436
		// TODO(typhoonzero): deal with failed tasks
437 438
		s.state.Failed = []taskEntry{}
		log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.state.CurPass)
439 440
	}

H
Helin Wang 已提交
441 442 443 444 445
	err := s.snapshot()
	if err != nil {
		log.Errorln(err)
	}
	return err
446
}
G
gongweibao 已提交
447

G
gongweibao 已提交
448
// TaskFailed tells the service that a task is failed.
449
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
G
gongweibao 已提交
450 451 452 453 454 455 456
	select {
	case <-s.ready:
	}

	s.mu.Lock()
	defer s.mu.Unlock()

457
	t, ok := s.state.Pending[meta.ID]
G
gongweibao 已提交
458
	if !ok {
G
gongweibao 已提交
459
		log.WithFields(s.logFields()).Warningln("TaskFailed:Pending task #%v not found.", t.Task.Meta)
G
gongweibao 已提交
460
		return nil
G
gongweibao 已提交
461 462
	}

G
gongweibao 已提交
463
	s.processFailedTask(t, meta.Epoch)
G
gongweibao 已提交
464 465
	return nil
}
466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504

// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
	// TrainerID identifies the requesting trainer; it must not be empty.
	TrainerID string
	// BlockDur is how long after an approved request the master keeps
	// the requesting trainer recorded as the one saving the model.
	BlockDur  time.Duration
}

// RequestSaveModel requests the master server to approve the caller
// to save the model.
//
// *need is set to true when no trainer is currently recorded as saving, or
// when the recorded trainer is the caller itself; otherwise it is set to
// false. An approved caller is recorded as the saving trainer and the
// record is cleared after req.BlockDur. An empty req.TrainerID is rejected
// with an error.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if req.TrainerID == "" {
		return errors.New("trainer id is empty")
	}

	// Approve when nobody is saving, or when the saving trainer asks
	// again.
	*need = s.savingTrainer == "" || s.savingTrainer == req.TrainerID
	if !*need {
		return nil
	}

	trainerID := req.TrainerID
	s.savingTrainer = trainerID
	time.AfterFunc(req.BlockDur, func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		// Only release the grant this timer belongs to: a timer left
		// over from an earlier grant must not clear the record of a
		// different trainer that was approved in the meantime.
		if s.savingTrainer == trainerID {
			s.savingTrainer = ""
		}
	})

	return nil
}