// Package master implements the master server service.
package master

import (
	"errors"
	"log"
	"sync"
	"time"

	"github.com/PaddlePaddle/recordio"
)

// targetTaskCount is not referenced by the code in this section; presumably
// it is a desired number of tasks per pass — TODO confirm or remove.
const targetTaskCount = 300

// Errors returned by Service methods.
var (
	// ErrNoMoreTask is returned by GetTask when the todo queue is empty,
	// i.e. there is no task left to hand out in the current pass.
	ErrNoMoreTask = errors.New("no more task for current pass")

	// ErrPendingTaskNotFound is returned by TaskFinished when the given task
	// ID is not in the pending queue.
	ErrPendingTaskNotFound = errors.New("pending task not found")
)

// Service is the master server service. It hands out tasks from its task
// queues and re-queues (or fails) tasks whose processing times out.
type Service struct {
	// timeoutDur is how long after a task is handed out by GetTask its
	// timeout check runs.
	timeoutDur time.Duration
	// timeoutMax is how many timeouts a task may accumulate and still be
	// re-queued; on timeout number timeoutMax+1 it moves to the failed queue.
	timeoutMax int

	// mu guards taskQueues; every method that touches the queues holds it.
	mu         sync.Mutex
	taskQueues taskQueues
}

// Recover is intended to rebuild the service state from the snapshot stored
// in etcd. It is not implemented yet: it always returns a nil *Service and a
// nil error, so a caller has to treat a nil Service as "nothing recovered".
func Recover() (*Service, error) {
	// TODO(helin): recover from snapshot state from etcd.
	var recovered *Service
	return recovered, nil
}

H
Helin Wang 已提交
37
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
38
	id := 0
H
Helin Wang 已提交
39 40
	if chunksPerTask <= 0 {
		chunksPerTask = 1
41 42 43 44 45
	}

	var result []taskEntry
	var cur taskEntry
	for i, c := range chunks {
H
Helin Wang 已提交
46
		if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
			cur.Task.ID = id
			id++
			result = append(result, cur)
			cur.Task.Chunks = nil
		}

		cur.Task.Chunks = append(cur.Task.Chunks, c)
	}

	if len(cur.Task.Chunks) > 0 {
		cur.Task.ID = id
		id++
		result = append(result, cur)
	}

	return result
}

// NewService creates a new service.
H
Helin Wang 已提交
66
func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
67 68 69 70 71
	s := &Service{}
	s.timeoutDur = timeoutDur
	s.timeoutMax = timeoutMax
	s.taskQueues = taskQueues{}
	s.taskQueues.Pending = make(map[int]taskEntry)
H
Helin Wang 已提交
72
	s.taskQueues.Todo = partition(chunks, chunksPerTask)
H
Helin Wang 已提交
73
	return s
74 75 76 77 78
}

// Chunk is a chunk of data consisted of several data instances.
type Chunk struct {
	Path  string
79
	Index recordio.Index // chunk index
80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
}

// Task is the basic unit of data instances assigned to trainers.
type Task struct {
	ID     int     // assigned sequentially from 0 by partition; key of the pending queue
	Chunks []Chunk // the chunks whose data instances make up this task
}

// taskEntry is the bookkeeping record the service keeps for one Task.
type taskEntry struct {
	// Epoch is incremented each time GetTask hands the task out; a scheduled
	// timeout check only acts if the epoch it captured is still current.
	// NumTimeout counts the task's timeouts and is reset to 0 when the task
	// finishes.
	Epoch, NumTimeout int

	Task Task
}

// taskQueues holds the service's tasks; tasks are moved between the queues
// as they are handed out, finish, time out, or fail.
type taskQueues struct {
	Todo    []taskEntry       // waiting to be handed out by GetTask
	Pending map[int]taskEntry // map from task ID to task entry; handed out, awaiting finish or timeout
	Done    []taskEntry       // finished in the current pass; recycled into Todo for the next pass
	Failed  []Task            // tasks that timed out more than timeoutMax times
}

// snapshot is intended to persist the service state on etcd. It is not
// implemented yet and always returns nil.
//
// The caller *must* hold s.mu.
func (s *Service) snapshot() error {
	// TODO(helin): snapshot state on etcd.
	var err error
	return err
}

// GetTask gets a new task from the service.
func (s *Service) GetTask(dummy int, task *Task) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.taskQueues.Todo) == 0 {
		return ErrNoMoreTask
	}

	t := s.taskQueues.Todo[0]
	t.Epoch++
	s.taskQueues.Todo = s.taskQueues.Todo[1:]
	s.taskQueues.Pending[t.Task.ID] = t
	err := s.snapshot()
	if err != nil {
		return err
	}

125 126
	*task = t.Task

127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177
	time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
		return func() {
			s.mu.Lock()
			defer s.mu.Unlock()

			t, ok := s.taskQueues.Pending[taskID]
			if !ok {
				return
			}

			if t.Epoch != epoch {
				// new epoch, task launched after the
				// schedule of this timeout check.
				return
			}

			defer func() {
				err := s.snapshot()
				if err != nil {
					log.Println(err)
				}
			}()

			delete(s.taskQueues.Pending, t.Task.ID)

			t.NumTimeout++
			if t.NumTimeout > s.timeoutMax {
				s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
				return
			}

			s.taskQueues.Todo = append(s.taskQueues.Todo, t)
		}
	}(t.Task.ID, t.Epoch))
	return nil
}

// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	t, ok := s.taskQueues.Pending[taskID]
	if !ok {
		return ErrPendingTaskNotFound
	}

	// task finished, reset timeout
	t.NumTimeout = 0
	s.taskQueues.Done = append(s.taskQueues.Done, t)
	delete(s.taskQueues.Pending, taskID)
178 179 180 181 182 183

	if len(s.taskQueues.Todo) == 0 {
		s.taskQueues.Todo = s.taskQueues.Done
		s.taskQueues.Done = nil
	}

184 185
	return s.snapshot()
}