// Package master implements the master server service, which partitions
// data chunks into tasks and dispatches them to trainers.
package master

import (
	"errors"
	"log"
	"sync"
	"time"

	"github.com/PaddlePaddle/Paddle/paddle/go/recordio"
)

// targetTaskCount is the number of tasks NewService asks partition to split
// the input chunks into. It is a target rather than an exact count, because
// partition rounds the per-task chunk count down.
const targetTaskCount = 300

// ErrNoMoreTask is returned by GetTask when the todo queue is empty.
var ErrNoMoreTask = errors.New("no more task for current pass")

// ErrPendingTaskNotFound is returned by TaskFinished when the given task ID
// is not in the pending queue.
var ErrPendingTaskNotFound = errors.New("pending task not found")

// Service is the master server service.
//
// It hands tasks out to trainers (GetTask) and records their completion
// (TaskFinished). All task bookkeeping lives in taskQueues; every method in
// this file that touches taskQueues after construction holds mu while doing so.
type Service struct {
	// timeoutDur is how long a dispatched task may stay pending before its
	// timeout check fires.
	timeoutDur time.Duration
	// timeoutMax is how many timeouts a task may accumulate and still be
	// requeued; one more timeout moves it to the failed list.
	timeoutMax int

	// mu guards taskQueues.
	mu         sync.Mutex
	taskQueues taskQueues
}

// Recover is meant to rebuild a Service from the state snapshotted to etcd.
//
// It is not implemented yet: it always returns a nil *Service together with
// a nil error, so callers have to check for a nil service themselves.
func Recover() (*Service, error) {
	// TODO(helin): recover from snapshot state from etcd.
	var recovered *Service
	return recovered, nil
}

// partition splits chunks into consecutive tasks, aiming for about
// targetTaskCount tasks.
//
// Each task holds len(chunks)/targetTaskCount chunks (at least one), except
// possibly the last, which holds the remainder. Task IDs are assigned
// sequentially starting at 0. Because the per-task count is rounded down, the
// result can contain up to 2*targetTaskCount-1 tasks.
//
// A non-positive targetTaskCount is treated as "one chunk per task". The
// previous implementation panicked with a division by zero for 0.
// An empty chunks slice yields nil.
func partition(chunks []Chunk, targetTaskCount int) []taskEntry {
	if len(chunks) == 0 {
		return nil
	}

	chunkPerTask := 1
	if targetTaskCount > 0 {
		if n := len(chunks) / targetTaskCount; n > 1 {
			chunkPerTask = n
		}
	}

	result := make([]taskEntry, 0, (len(chunks)+chunkPerTask-1)/chunkPerTask)
	for start := 0; start < len(chunks); start += chunkPerTask {
		end := start + chunkPerTask
		if end > len(chunks) {
			end = len(chunks)
		}

		var entry taskEntry
		entry.Task.ID = len(result)
		// Copy rather than sub-slice so a task never shares a backing array
		// with the caller's chunks slice.
		entry.Task.Chunks = append([]Chunk(nil), chunks[start:end]...)
		result = append(result, entry)
	}
	return result
}

// NewService creates a new service.
H
Helin Wang 已提交
67
func NewService(chunks []Chunk, timeoutDur time.Duration, timeoutMax int) *Service {
68 69 70 71 72 73
	s := &Service{}
	s.timeoutDur = timeoutDur
	s.timeoutMax = timeoutMax
	s.taskQueues = taskQueues{}
	s.taskQueues.Pending = make(map[int]taskEntry)
	s.taskQueues.Todo = partition(chunks, targetTaskCount)
H
Helin Wang 已提交
74
	return s
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
}

// Chunk is a chunk of data consisting of several data instances.
//
// NOTE(review): Path presumably names the file the chunk lives in, with
// Idx locating the chunk inside it — confirm against the code that builds
// Chunks.
type Chunk struct {
	Idx   int // index of the chunk within the file
	Path  string
	Index recordio.Index // block index
}

// Task is the basic unit of data instances assigned to trainers.
//
// partition assigns IDs sequentially from 0, and each task's Chunks is a run
// of consecutive input chunks. The ID is also the key of the pending queue.
type Task struct {
	ID     int
	Chunks []Chunk
}

// taskEntry is a Task plus the master's dispatch bookkeeping for it.
type taskEntry struct {
	// Epoch is incremented each time GetTask hands the task out. A timeout
	// check only acts on the task if its epoch still matches the one it was
	// scheduled with.
	Epoch      int
	// NumTimeout counts how many times the task has timed out while
	// pending. TaskFinished resets it to 0.
	NumTimeout int
	Task       Task
}

// taskQueues holds the tasks in each dispatch state. Tasks move between the
// queues as follows:
//   - Todo -> Pending when handed out by GetTask;
//   - Pending -> Done on TaskFinished;
//   - Pending -> Todo on a timeout;
//   - Pending -> Failed on a timeout past the limit.
type taskQueues struct {
	Todo    []taskEntry
	Pending map[int]taskEntry // map from task ID to task entry
	Done    []taskEntry
	Failed  []Task
}

// snapshot is meant to persist the current service state. It is still a stub:
// nothing is written and it always reports success.
//
// The caller *must* hold s.mu.
func (s *Service) snapshot() error {
	// TODO(helin): snapshot state on etcd.
	var err error // nothing is persisted yet, so nothing can fail
	return err
}

// GetTask moves the next task from the todo queue to the pending queue and
// stores it in *task. The dummy argument is unused. NOTE(review): the
// (args, *reply) error shape presumably exists to satisfy net/rpc — confirm.
//
// ErrNoMoreTask is returned when the todo queue is empty. If persisting the
// new state fails, the dispatch is undone and the snapshot error is returned.
//
// Each dispatch arms a timer. When s.timeoutDur elapses, the timer checks
// whether the task is still pending with the same epoch. If so, the task is
// requeued, or moved to the failed list once it has timed out more than
// s.timeoutMax times.
func (s *Service) GetTask(dummy int, task *Task) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if len(s.taskQueues.Todo) == 0 {
		return ErrNoMoreTask
	}

	todo := s.taskQueues.Todo
	t := todo[0] // copy; todo[0] keeps the old epoch in case we roll back
	t.Epoch++
	s.taskQueues.Todo = todo[1:]
	s.taskQueues.Pending[t.Task.ID] = t
	if err := s.snapshot(); err != nil {
		// Undo the dispatch. Otherwise the task would sit in Pending with no
		// timeout timer armed and could never be handed out again.
		delete(s.taskQueues.Pending, t.Task.ID)
		s.taskQueues.Todo = todo
		return err
	}

	// The original implementation never filled in the reply, so callers
	// always received an empty Task.
	*task = t.Task

	taskID, epoch := t.Task.ID, t.Epoch
	time.AfterFunc(s.timeoutDur, func() {
		s.mu.Lock()
		defer s.mu.Unlock()

		entry, ok := s.taskQueues.Pending[taskID]
		if !ok || entry.Epoch != epoch {
			// Either the task is no longer pending, or it was handed out
			// again (new epoch) after this check was scheduled.
			return
		}

		delete(s.taskQueues.Pending, taskID)
		entry.NumTimeout++
		if entry.NumTimeout > s.timeoutMax {
			s.taskQueues.Failed = append(s.taskQueues.Failed, entry.Task)
		} else {
			s.taskQueues.Todo = append(s.taskQueues.Todo, entry)
		}

		// Best effort: there is no caller to return the error to here.
		if err := s.snapshot(); err != nil {
			log.Println(err)
		}
	})
	return nil
}

// TaskFinished tells the service that the pending task taskID is finished.
// The task moves from the pending queue to the done list, with its timeout
// counter reset. The dummy argument is unused.
//
// ErrPendingTaskNotFound is returned when taskID is not pending, for example
// because it already finished or timed out. Otherwise the result of
// persisting the new state is returned.
func (s *Service) TaskFinished(taskID int, dummy *int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	entry, pending := s.taskQueues.Pending[taskID]
	if !pending {
		return ErrPendingTaskNotFound
	}
	delete(s.taskQueues.Pending, taskID)

	// A finished task starts over with a clean timeout record.
	entry.NumTimeout = 0
	s.taskQueues.Done = append(s.taskQueues.Done, entry)

	return s.snapshot()
}