service.go 5.7 KB
Newer Older
1 2 3
package pserver

import (
D
dongzhihong 已提交
4 5 6 7
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
D
dongzhihong 已提交
8
	"encoding/json"
9 10
	"errors"
	"fmt"
D
dongzhihong 已提交
11
	"os"
D
dongzhihong 已提交
12
	"path/filepath"
D
dongzhihong 已提交
13
	"strconv"
14
	"sync"
D
dongzhihong 已提交
15 16 17
	"time"

	log "github.com/sirupsen/logrus"
18 19 20 21 22
)

// ElementType is the type of elements of a Parameter.
// It is an integer tag; the supported values are the ElementType constants
// declared in this file (Int32 through Float64).
type ElementType int

23
// Error messages used by the Service RPC methods to report the
// initialization state of the pserver.
const (
	// AlreadyInitialized is the message of the error returned by InitParam
	// and FinishInitParams once the pserver has been initialized.
	AlreadyInitialized = "pserver already initialized"
	// Uninitialized is the message of the error returned by SendGrad while
	// the pserver has not finished initialization.
	Uninitialized = "pserver not fully initialized"
)
29 30 31 32 33 34 35 36 37 38 39 40 41 42 43

// Supported element types.
//
// The values are consecutive integers starting at zero (iota). doCheckpoint
// writes a parameter's ElementType into the gob-encoded checkpoint, so
// reordering these constants changes the numbers stored there.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
44
	Content     []byte
45 46 47 48 49 50 51 52
}

// ParameterWithConfig contains the parameter and the configuration.
// It is the argument of Service.InitParam and the value an optimizer is
// created from.
type ParameterWithConfig struct {
	// Param is the parameter itself.
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

D
dongzhihong 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
// parameterCheckPoint is the checkpoint record of a single parameter.
//
// ParamConfig carries the parameter; doCheckpoint fills in only its Param
// part (name, element type and the weights returned by the optimizer).
// State is the optimizer state as returned by the optimizer's GetStates.
// NewService passes both back to newOptimizer when restoring.
type parameterCheckPoint struct {
	ParamConfig ParameterWithConfig
	State       []byte
}

// checkpointMeta is the checkpoint signature. doCheckpoint marshals it to
// JSON and hands the result to EtcdClient.PutKey.
//
// UUID is set to checkpointPath followed by the pserver index and is also
// used as the path of the checkpoint file. Md5sum is meant to hold the MD5
// checksum of the encoded checkpoint. Timestamp is time.Now().String() at
// the moment the checkpoint is taken.
type checkpointMeta struct {
	UUID      string `json:"uuid"`
	Md5sum    string `json:"md5sum"`
	Timestamp string `json:"timestamp"`
}

// Checkpoint is the pserver shard persisted in a file: one
// parameterCheckPoint per parameter. NewService accepts a Checkpoint and
// recreates one optimizer per entry.
type Checkpoint []parameterCheckPoint

69 70
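// Gradient is the gradient of the parameter.
// NOTE(review): no Gradient declaration follows this comment in this file,
// although SendGrad takes a Gradient; confirm the type is declared elsewhere
// in the package.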

H
Helin Wang 已提交
71
// Service is the RPC service for pserver.
72
type Service struct {
D
dongzhihong 已提交
73 74 75 76 77 78 79
	initialized        chan struct{}
	idx                int
	checkpointInterval int
	checkpointPath     string
	client             *EtcdClient
	mu                 sync.Mutex
	optMap             map[string]*optimizer
80 81
}

D
dongzhihong 已提交
82 83
// Disabled helper, kept for reference:
// // serialize ParameterWithConfig to byte stream
// func GetBytes(content ...interface{}) ([]byte, error) {

// 	var buf bytes.Buffer
// 	encoder := gob.NewEncoder(&buf)
// 	err := encoder.Encode(content)
// 	if err != nil {
// 		return nil, err
// 	}
// 	return buf.Bytes(), nil
// }
D
dongzhihong 已提交
93

W
wuyi05 已提交
94 95
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
D
dongzhihong 已提交
96
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
97
	s := &Service{
D
dongzhihong 已提交
98 99 100 101
		idx:                idx,
		checkpointInterval: time.Second * time.Duration(seconds),
		checkpointPath:     path,
		client:             client,
102
	}
D
dongzhihong 已提交
103
	s.optMap = make(map[string]*optimizer)
104
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
105 106 107 108 109 110 111 112

	if cp != nil {
		for _, item := range cp {
			p := item.ParamConfig
			st := item.State
			s.optMap[p.Param.Name] = newOptimizer(p, st)
		}
	}
W
wuyi05 已提交
113
	return s, nil
114 115
}

H
Helin Wang 已提交
116
// InitParam initializes a parameter.
117 118 119
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
120
		return errors.New(AlreadyInitialized)
121 122 123 124 125 126 127 128 129 130 131
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dzhwinter 已提交
132
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
133 134 135
	return nil
}

H
Helin Wang 已提交
136 137
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
138 139 140
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
141
		return errors.New(AlreadyInitialized)
142 143 144 145 146 147 148
	default:
	}

	close(s.initialized)
	return nil
}

149
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
150
// optimization.
151
func (s *Service) SendGrad(g Gradient, dummy *int) error {
152 153 154
	select {
	case <-s.initialized:
	default:
155
		return errors.New(Uninitialized)
156
	}
157 158

	s.mu.Lock()
H
Helin Wang 已提交
159
	defer s.mu.Unlock()
160

D
dzhwinter 已提交
161 162
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
163
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
164
	}
165

D
dongzhihong 已提交
166
	return o.UpdateParameter(g)
167 168
}

169 170
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
171 172
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
173
	defer s.mu.Unlock()
174

D
dongzhihong 已提交
175
	opt, ok := s.optMap[name]
176 177
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
178 179
	}

180 181 182 183 184 185 186
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
D
dongzhihong 已提交
187
	parameter.Name = name
D
dongzhihong 已提交
188
	parameter.ElementType = opt.elementType
189 190
	parameter.Content = opt.GetWeights()
	return nil
191 192
}

D
dongzhihong 已提交
193 194
// pserver save checkpoint
func (s *Service) doCheckpoint() error {
195
	<-s.initialized
D
dongzhihong 已提交
196 197
	s.mu.Lock()
	defer s.mu.Unlock()
D
dongzhihong 已提交
198 199 200

	cp := make([]parameterCheckPoint, 0, len(s.optMap))
	index := 0
D
dongzhihong 已提交
201
	for name, opt := range s.optMap {
D
dongzhihong 已提交
202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
		var pc parameterCheckPoint
		pc.ParamConfig.Param.Name = name
		pc.ParamConfig.Param.ElementType = opt.elementType
		pc.ParamConfig.Param.Content = opt.GetWeights()
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
	err := encoder.Encode(cp)
	if err != nil {
		return err
	}

	cpMeta := checkpointMeta{}
	cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
	cpMeta.Timestamp = time.Now().String()
	h := md5.New()
	cpMeta.Md5sum = h.Sum(buf.Bytes())

	cpMetajson, err := json.Marshal(cpMeta)
	s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
	if err != nil {
		return err
	}
	if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
		log.Info("checkpoint does not exists.")
	} else {
		err = os.Remove(cpMeta.UUID)
		log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
	}
	f, err := os.Create(cpMeta.UUID)
	defer f.Close()
	if err != nil {
		log.Errorln(err)
	}
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	writer.Flush()
	if err != nil {
		log.Errorln(err)
D
dongzhihong 已提交
244
	}
245 246
	return nil
}