service.go 11.7 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package master

import (
18 19 20
	"bytes"
	"compress/gzip"
	"encoding/gob"
21
	"errors"
22
	"math/rand"
23 24
	"os"
	"path/filepath"
25 26 27
	"sync"
	"time"

28
	log "github.com/inconshreveable/log15"
H
Helin Wang 已提交
29

30
	"github.com/PaddlePaddle/recordio"
31 32
)

33 34 35 36
const (
	// dialTimeout is a five second timeout. It is not referenced in the
	// visible part of this file.
	// NOTE(review): presumably used when dialing the backing store from
	// another file of this package - confirm against its users.
	dialTimeout = 5 * time.Second
)

37 38 39 40 41 42 43 44 45 46 47 48
// ErrAllTaskFailed is returned by GetTask when there is nothing left to
// dispatch, nothing pending and nothing done, i.e. every task of the
// current pass ended up in the failed state.
var ErrAllTaskFailed = errors.New("all task finished")

// ErrNoMoreAvailable is returned by GetTask when the todo queue is empty
// but some tasks are still pending or already done.
var ErrNoMoreAvailable = errors.New("no more available task")

// ErrPassBefore is returned by GetTask when the client side pass number
// is smaller than the master's current pass.
var ErrPassBefore = errors.New("pass number smaller than master")

// ErrPassAfter is returned by GetTask when the client side pass number
// is larger than the master's current pass.
var ErrPassAfter = errors.New("pass number larger than master")

49 50 51 52
// Store is the persistence backend used to save and load the serialized
// master state.
type Store interface {
	// Save persists the serialized state.
	Save(state []byte) error
	// Load returns the previously saved serialized state. The service
	// treats a nil state with a nil error as "nothing saved yet".
	Load() (state []byte, err error)
	// Shutdown is not called from this file.
	// NOTE(review): presumably releases the store's resources - confirm
	// against the implementations.
	Shutdown() error
}

// Chunk is a chunk of data consisting of several data instances.
type Chunk struct {
	// Path is the path of the file the chunk index was loaded from.
	Path  string
	Index recordio.Index // chunk index
}

G
gongweibao 已提交
62 63 64 65 66 67
// TaskMeta is a struct which stores task's meta info.
type TaskMeta struct {
	// ID identifies the task; it is assigned when the dataset is
	// partitioned into tasks.
	ID    int
	// Epoch is incremented every time the task is dispatched by GetTask.
	// It is compared against the epoch of a timeout check or failure
	// report so that stale ones are ignored.
	Epoch int
}

68 69
// Task is the basic unit of data instances assigned to trainers.
type Task struct {
G
gongweibao 已提交
70
	Meta   TaskMeta
71 72 73 74
	Chunks []Chunk
}

type taskEntry struct {
G
gongweibao 已提交
75 76 77
	Task Task
	// A task fails if it's timeout or trainer reports it exits unnormally.
	NumFailure int
78 79
}

80
type masterState struct {
81 82 83 84 85
	Todo    []taskEntry
	Pending map[int]taskEntry // map from task ID to task entry
	Done    []taskEntry
	Failed  []taskEntry
	CurPass int
86 87
}

88 89
// Service is the master server service.
type Service struct {
G
gongweibao 已提交
90 91 92 93
	chunksPerTask int
	timeoutDur    time.Duration
	failureMax    int
	store         Store
94

95 96 97
	ready    chan struct{}
	initDone bool

98 99 100 101 102
	mu sync.Mutex
	// State to be persisted to snapshot.
	state masterState
	// The trainer that is currently saving model. This state is
	// transient, does not need to be persisted to snapshot.
103
	savingTrainer string
104 105
}

H
Helin Wang 已提交
106
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
107 108 109 110 111 112
	// generate uniq id across job using nanosecond + randint + counter
	// FIXME(typhoonzero): this is a workaround, use uuid
	randStart := rand.Int()
	counter := 0
	timestamp := time.Now().Nanosecond()
	id := timestamp + randStart + counter
H
Helin Wang 已提交
113 114
	if chunksPerTask <= 0 {
		chunksPerTask = 1
115 116 117 118 119
	}

	var result []taskEntry
	var cur taskEntry
	for i, c := range chunks {
H
Helin Wang 已提交
120
		if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
G
gongweibao 已提交
121
			cur.Task.Meta.ID = id
122 123
			counter++
			id = timestamp + randStart + counter
124 125 126 127 128 129 130 131
			result = append(result, cur)
			cur.Task.Chunks = nil
		}

		cur.Task.Chunks = append(cur.Task.Chunks, c)
	}

	if len(cur.Task.Chunks) > 0 {
G
gongweibao 已提交
132
		cur.Task.Meta.ID = id
133 134 135 136 137 138 139
		result = append(result, cur)
	}

	return result
}

// NewService creates a new service.
G
gongweibao 已提交
140
func NewService(store Store, chunksPerTask int, timeoutDur time.Duration, failureMax int) (*Service, error) {
141
	s := &Service{}
142
	s.chunksPerTask = chunksPerTask
143
	s.timeoutDur = timeoutDur
G
gongweibao 已提交
144
	s.failureMax = failureMax
145 146
	s.state = masterState{}
	s.state.Pending = make(map[int]taskEntry)
147
	s.ready = make(chan struct{})
148 149 150 151 152
	s.store = store
	recovered, err := s.recover()
	if err != nil {
		return nil, err
	}
153

154 155 156 157 158
	if recovered {
		// Recovered. Now the state is already initialized,
		// and the master is ready.
		s.initDone = true
		close(s.ready)
159
		log.Info("Master recovered from saved state.")
160
	}
161

162
	return s, nil
163 164
}

165 166 167 168 169 170
// recover recovers service state from etcd.
func (s *Service) recover() (bool, error) {
	state, err := s.store.Load()
	if err != nil {
		return false, err
	}
171

172
	if state == nil {
173
		log.Info("No state exists, not recovered.")
174 175 176
		return false, nil
	}

177
	log.Info("Loaded snapshot.", log.Ctx{"size": len(state)})
178 179 180 181 182 183
	gr, err := gzip.NewReader(bytes.NewReader(state))
	if err != nil {
		return false, err
	}

	dec := gob.NewDecoder(gr)
184
	var tqs masterState
185 186 187 188 189 190 191 192 193
	err = dec.Decode(&tqs)
	if err != nil {
		return false, err
	}

	err = gr.Close()
	if err != nil {
		// Only close failed, recover actually succeed, so
		// just log error.
194
		log.Error("error close recover file.", log.Ctx{"error": err})
195 196
	}

197
	s.state = tqs
198
	log.Info("Master recovered from snapshot, scheduling pending task timeout check.", s.logCtx())
199 200 201 202
	for _, t := range s.state.Pending {
		time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
	}

203
	return true, nil
204 205
}

206
// snapshot *must* be called with s.mu being held.
207
func (s *Service) snapshot() error {
208
	// TODO(helin): etcd request has a size limit, so the snapshot
209 210 211 212 213 214 215 216
	// size is limited by the max request size. We should either
	// divide the snapshot into smaller chunks and save under
	// different keys, or configure the request size to be big
	// enough:
	// https://github.com/coreos/etcd/blob/2f84f3d8d8ed8f9537ab6ffa44a3a1c7eddfa9b1/embed/config.go#L44
	var buf bytes.Buffer
	gw := gzip.NewWriter(&buf)
	enc := gob.NewEncoder(gw)
217
	err := enc.Encode(s.state)
218 219 220 221 222 223 224 225 226
	if err != nil {
		return err
	}
	err = gw.Close()
	if err != nil {
		return err
	}

	state := buf.Bytes()
227
	log.Info("Saving snapshot.", log.Ctx{"size bytes": len(state)})
228
	return s.store.Save(state)
229 230
}

H
Helin Wang 已提交
231
func readChunks(globPaths []string) ([]Chunk, error) {
232 233 234 235 236 237
	var chunks []Chunk
	var paths []string

	for _, s := range globPaths {
		match, err := filepath.Glob(s)
		if err != nil {
H
Helin Wang 已提交
238
			return nil, err
239 240 241 242 243
		}
		paths = append(paths, match...)
	}

	if len(paths) == 0 {
H
Helin Wang 已提交
244
		return nil, errors.New("no valid dataset specified")
245 246 247 248 249
	}

	for _, path := range paths {
		f, err := os.Open(path)
		if err != nil {
H
Helin Wang 已提交
250
			return nil, err
251 252 253 254
		}

		index, err := recordio.LoadIndex(f)
		if err != nil {
H
Helin Wang 已提交
255
			return nil, err
256 257 258
		}
		err = f.Close()
		if err != nil {
H
Helin Wang 已提交
259
			return nil, err
260 261 262
		}

		count := index.NumChunks()
263
		log.Info("reading chunks.", log.Ctx{"path": path, "num chunks": count})
264 265 266 267 268 269 270 271 272
		for i := 0; i < count; i++ {
			chunk := Chunk{
				Path:  path,
				Index: *index.ChunkIndex(i),
			}
			chunks = append(chunks, chunk)
		}
	}

H
Helin Wang 已提交
273 274 275 276 277 278 279
	return chunks, nil
}

// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
280
func (s *Service) SetDataset(globPaths []string, _ *int) error {
H
Helin Wang 已提交
281 282 283 284 285 286 287 288 289 290 291 292 293
	if len(globPaths) == 0 {
		return errors.New("no dataset specified")
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.initDone {
		// Already initialized. All trainer will call
		// SetDataset, but we only handle the first one. Treat
		// other calls as successful but do nothing.
		return nil
	}

H
Helin Wang 已提交
294
	chunks, err := readChunks(globPaths)
H
Helin Wang 已提交
295 296 297 298
	if err != nil {
		return err
	}

299
	s.state.Todo = partition(chunks, s.chunksPerTask)
300

H
Helin Wang 已提交
301
	err = s.snapshot()
302
	if err != nil {
303
		log.Error("snapshot error", log.Ctx{"error": err})
304 305 306
		return err
	}
	close(s.ready)
H
Helin Wang 已提交
307
	s.initDone = true
308 309 310
	return nil
}

311 312
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
G
gongweibao 已提交
313 314
func (s *Service) processFailedTask(t taskEntry, epoch int) {
	if t.Task.Meta.Epoch != epoch {
G
gongweibao 已提交
315 316 317 318 319 320 321 322
		// new epoch, task launched after the
		// schedule of this timeout check or failed status report.
		return
	}

	defer func() {
		err := s.snapshot()
		if err != nil {
323
			log.Error("snapshot error", log.Ctx{"error": err})
G
gongweibao 已提交
324 325 326
		}
	}()

327
	delete(s.state.Pending, t.Task.Meta.ID)
G
gongweibao 已提交
328

G
gongweibao 已提交
329 330
	t.NumFailure++
	if t.NumFailure > s.failureMax {
331
		log.Warn("Task failed to many times, discard.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
332
		s.state.Failed = append(s.state.Failed, t)
G
gongweibao 已提交
333 334 335
		return
	}

336
	log.Warn("Task failed, re-dispatch.", log.Ctx{"task": t.Task, "num failed": t.NumFailure})
337
	s.state.Todo = append(s.state.Todo, t)
338
	return
G
gongweibao 已提交
339 340
}

H
Helin Wang 已提交
341 342 343 344 345
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
	return func() {
		s.mu.Lock()
		defer s.mu.Unlock()

346
		t, ok := s.state.Pending[taskID]
H
Helin Wang 已提交
347 348 349 350
		if !ok {
			return
		}

G
gongweibao 已提交
351
		s.processFailedTask(t, epoch)
H
Helin Wang 已提交
352 353 354
	}
}

H
Helin Wang 已提交
355
// must be called with lock held.
356 357
func (s *Service) logCtx() log.Ctx {
	return log.Ctx{
358 359 360 361
		"todoLen":    len(s.state.Todo),
		"pendingLen": len(s.state.Pending),
		"doneLen":    len(s.state.Done),
		"failedLen":  len(s.state.Failed),
362
		"curPass":    s.state.CurPass,
H
Helin Wang 已提交
363 364 365
	}
}

366
// GetTask gets a new task from the service.
367 368
// passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
369 370 371 372
	select {
	case <-s.ready:
	}

373 374
	s.mu.Lock()
	defer s.mu.Unlock()
375
	if passID < s.state.CurPass {
376 377
		return ErrPassBefore
	}
378
	if passID > s.state.CurPass {
379 380 381 382
		// Client may get run to pass after master when one client faster than the
		// other
		return ErrPassAfter
	}
383

384 385
	if len(s.state.Todo) == 0 {
		if len(s.state.Done) == 0 && len(s.state.Pending) == 0 {
386
			log.Warn("All tasks failed, may start next pass", s.logCtx())
387
			return ErrAllTaskFailed
388
		}
389
		log.Warn("No more available task.", s.logCtx())
390
		return ErrNoMoreAvailable
391 392
	}

393
	t := s.state.Todo[0]
G
gongweibao 已提交
394
	t.Task.Meta.Epoch++
395 396
	s.state.Todo = s.state.Todo[1:]
	s.state.Pending[t.Task.Meta.ID] = t
397 398 399 400 401
	err := s.snapshot()
	if err != nil {
		return err
	}

402
	*task = t.Task
403 404 405
	ctx := s.logCtx()
	ctx["task meta"] = t.Task.Meta
	log.Info("Task dispatched.", ctx)
G
gongweibao 已提交
406
	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.Meta.ID, t.Task.Meta.Epoch))
407 408 409 410
	return nil
}

// TaskFinished tell the service that a task is finished.
411
func (s *Service) TaskFinished(taskID int, dummy *int) error {
412 413 414 415
	select {
	case <-s.ready:
	}

416 417 418
	s.mu.Lock()
	defer s.mu.Unlock()

419
	t, ok := s.state.Pending[taskID]
420
	if !ok {
421 422 423
		ctx := s.logCtx()
		ctx["task id"] = taskID
		log.Warn("Pending task not found.", ctx)
G
gongweibao 已提交
424
		return nil
425 426 427
	}

	// task finished, reset timeout
G
gongweibao 已提交
428
	t.NumFailure = 0
429 430
	s.state.Done = append(s.state.Done, t)
	delete(s.state.Pending, taskID)
431

432 433 434
	ctx := s.logCtx()
	ctx["task id"] = taskID
	log.Info("Task finished.", ctx)
435
	if len(s.state.Todo) == 0 && len(s.state.Pending) == 0 {
436
		// increase master side pass count if all tasks finished
437
		s.state.CurPass++
438
		s.state.Todo = append(s.state.Done, s.state.Failed...)
439
		s.state.Done = []taskEntry{}
440
		// TODO(typhoonzero): deal with failed tasks
441
		s.state.Failed = []taskEntry{}
442 443 444
		ctx := s.logCtx()
		ctx["new pass"] = s.state.CurPass
		log.Warn("all task finished, add new pass data.", ctx)
445 446
	}

H
Helin Wang 已提交
447 448
	err := s.snapshot()
	if err != nil {
449
		log.Error("snapshot error", log.Ctx{"error": err})
H
Helin Wang 已提交
450 451
	}
	return err
452
}
G
gongweibao 已提交
453

G
gongweibao 已提交
454
// TaskFailed tells the service that a task is failed.
455
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
G
gongweibao 已提交
456 457 458 459 460 461 462
	select {
	case <-s.ready:
	}

	s.mu.Lock()
	defer s.mu.Unlock()

463
	t, ok := s.state.Pending[meta.ID]
G
gongweibao 已提交
464
	if !ok {
465
		log.Warn("TaskFailed:Pending task not found.", log.Ctx{"task": t.Task.Meta})
G
gongweibao 已提交
466
		return nil
G
gongweibao 已提交
467 468
	}

G
gongweibao 已提交
469
	s.processFailedTask(t, meta.Epoch)
G
gongweibao 已提交
470 471
	return nil
}
472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510

// SaveModelRequest is the request for saving model
type SaveModelRequest struct {
	// TrainerID identifies the requesting trainer; RequestSaveModel
	// rejects an empty value.
	TrainerID string
	// BlockDur is how long an approval keeps the trainer registered as
	// the saving trainer before the registration is cleared.
	BlockDur  time.Duration
}

// RequestSaveModel requests the master server to approve the caller
// to save the model.
//
// *need is set to true when no trainer is currently registered as saving
// or when the caller is that trainer, and to false otherwise. On
// approval the caller is registered as the saving trainer and a timer
// clears the registration after req.BlockDur. An error is returned, and
// *need is left untouched, when req.TrainerID is empty.
//
// NOTE(review): every approval starts its own timer, so a timer from an
// earlier approval clears the registration when it fires, regardless of
// later approvals - confirm this is intended.
func (s *Service) RequestSaveModel(req SaveModelRequest, need *bool) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if req.TrainerID == "" {
		return errors.New("trainer id is empty")
	}

	// Approve when nobody is saving, or when the saving trainer asks
	// to save the model again.
	approved := s.savingTrainer == "" || s.savingTrainer == req.TrainerID
	*need = approved
	if !approved {
		return nil
	}

	s.savingTrainer = req.TrainerID
	time.AfterFunc(req.BlockDur, func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.savingTrainer = ""
	})

	return nil
}