service.go 6.4 KB
Newer Older
1 2 3
package pserver

import (
D
dongzhihong 已提交
4 5 6 7
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
D
dongzhihong 已提交
8
	"encoding/hex"
D
dongzhihong 已提交
9
	"encoding/json"
10 11
	"errors"
	"fmt"
12
	"io/ioutil"
D
dongzhihong 已提交
13
	"os"
D
dongzhihong 已提交
14
	"path/filepath"
D
dongzhihong 已提交
15
	"strconv"
16
	"sync"
D
dongzhihong 已提交
17 18 19
	"time"

	log "github.com/sirupsen/logrus"
20 21 22 23 24
)

// ElementType identifies the numeric type of the values stored in a
// Parameter's Content.
type ElementType int

// Error messages returned by the RPC methods of Service.
const (
	// AlreadyInitialized is returned by InitParam and FinishInitParams once
	// FinishInitParams has completed.
	AlreadyInitialized = "pserver already initialized"

	// Uninitialized is returned by SendGrad before FinishInitParams has
	// completed.
	Uninitialized = "pserver not fully initialized"

	// CheckpointMD5Failed is returned by NewCheckpointFromFile when the
	// checkpoint file does not match the checksum in its metadata.
	CheckpointMD5Failed = "checkpoint file MD5 validation failed"
)

32
// Supported element types.
33 34 35 36 37 38 39 40 41 42 43 44 45
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
46
	Content     []byte
47 48 49 50 51 52 53 54
}

// ParameterWithConfig bundles a Parameter with its configuration. It is the
// argument of InitParam and part of every checkpoint entry.
type ParameterWithConfig struct {
	Param Parameter

	// Config is the parameter configuration serialized in Protocol Buffers
	// format.
	Config []byte
}

// checkpointMeta describes a checkpoint file. It is stored in etcd as JSON so
// that a loader can find the file and validate its content.
type checkpointMeta struct {
	// UUID is joined with the checkpoint directory to locate the file.
	UUID string `json:"uuid"`

	// MD5 is the checksum string the file content is validated against.
	MD5 string `json:"md5"`

	// Timestamp is when the checkpoint was taken, in nanoseconds since the
	// Unix epoch.
	Timestamp int64 `json:"timestamp"`
}

// Checkpoint is the pserver shard persist in file
63
type Checkpoint []parameterCheckpoint
D
dongzhihong 已提交
64

65
// Gradient is the gradient of the parameter.
D
dongzhihong 已提交
66
type Gradient Parameter
67

H
Helin Wang 已提交
68
// Service is the RPC service for pserver.
69
type Service struct {
D
dongzhihong 已提交
70 71
	initialized        chan struct{}
	idx                int
D
dongzhihong 已提交
72
	checkpointInterval time.Duration
D
dongzhihong 已提交
73 74 75 76
	checkpointPath     string
	client             *EtcdClient
	mu                 sync.Mutex
	optMap             map[string]*optimizer
77 78
}

79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
// parameterCheckpoint saves parameter checkpoint
type parameterCheckpoint struct {
	ParameterWithConfig
	State []byte
}

// NewCheckpointFromFile loads the checkpoint of the pserver with index idx.
//
// The checkpoint metadata is read from etcd through e, under the same key
// doCheckpoint publishes it to. The checkpoint file is then read from cpPath
// joined with the UUID recorded in the metadata, verified against the
// recorded MD5 and gob-decoded.
//
// Errors from etcd, JSON decoding, file reading and gob decoding are returned
// unchanged (so os.IsNotExist still recognizes a missing file). An error with
// the message CheckpointMD5Failed is returned when the file content does not
// match the recorded checksum.
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (*Checkpoint, error) {
	// Use strconv.Itoa, not string(idx): converting an int to a string
	// yields the rune with that code point (e.g. "\x00" for 0), not its
	// decimal digits. The key mirrors the one doCheckpoint writes.
	v, err := e.GetKey(filepath.Join(PsCheckpoint, strconv.Itoa(idx)), 3*time.Second)
	if err != nil {
		return nil, err
	}

	var cpMeta checkpointMeta
	if err = json.Unmarshal(v, &cpMeta); err != nil {
		return nil, err
	}

	content, err := ioutil.ReadFile(filepath.Join(cpPath, cpMeta.UUID))
	if err != nil {
		return nil, err
	}

	sum := md5.Sum(content)
	if hex.EncodeToString(sum[:]) != cpMeta.MD5 {
		// Metadata produced with hex.EncodeToString(md5.New().Sum(content))
		// holds hex(content || MD5("")) instead of the digest of content,
		// because hash.Hash.Sum appends the digest of the (empty) hash state
		// to its argument. Accept that legacy form so such checkpoints stay
		// loadable. The full slice expression caps the capacity so Sum
		// allocates rather than writing into content's spare capacity.
		legacy := md5.New().Sum(content[:len(content):len(content)])
		if hex.EncodeToString(legacy) != cpMeta.MD5 {
			return nil, errors.New(CheckpointMD5Failed)
		}
	}

	cp := &Checkpoint{}
	if err = gob.NewDecoder(bytes.NewReader(content)).Decode(cp); err != nil {
		return nil, err
	}
	return cp, nil
}

W
wuyi05 已提交
120
// NewService creates a new service, will bypass etcd registration if no
121 122
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp *Checkpoint) (*Service, error) {
123
	s := &Service{
D
dongzhihong 已提交
124
		idx:                idx,
125
		checkpointInterval: interval,
D
dongzhihong 已提交
126 127
		checkpointPath:     path,
		client:             client,
128
	}
D
dongzhihong 已提交
129
	s.optMap = make(map[string]*optimizer)
130
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
131 132

	if cp != nil {
133 134 135 136 137 138
		for _, item := range *cp {
			p := ParameterWithConfig{
				Param:  item.Param,
				Config: item.Config,
			}
			s.optMap[p.Param.Name] = newOptimizer(p, item.State)
D
dongzhihong 已提交
139 140
		}
	}
W
wuyi05 已提交
141
	return s, nil
142 143
}

H
Helin Wang 已提交
144
// InitParam initializes a parameter.
145 146 147
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
148
		return errors.New(AlreadyInitialized)
149 150 151 152 153 154 155 156 157 158 159
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dongzhihong 已提交
160
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
161 162 163
	return nil
}

H
Helin Wang 已提交
164 165
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
166 167 168
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
169
		return errors.New(AlreadyInitialized)
170 171 172 173 174 175 176
	default:
	}

	close(s.initialized)
	return nil
}

177
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
178
// optimization.
179
func (s *Service) SendGrad(g Gradient, dummy *int) error {
180 181 182
	select {
	case <-s.initialized:
	default:
183
		return errors.New(Uninitialized)
184
	}
185 186

	s.mu.Lock()
H
Helin Wang 已提交
187
	defer s.mu.Unlock()
188

D
dzhwinter 已提交
189 190
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
191
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
192
	}
193

D
dongzhihong 已提交
194
	return o.UpdateParameter(g)
195 196
}

197 198
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
199 200
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
201
	defer s.mu.Unlock()
202

D
dongzhihong 已提交
203
	opt, ok := s.optMap[name]
204 205
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
206 207
	}

208 209 210 211 212 213 214
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
D
dongzhihong 已提交
215
	parameter.Name = name
D
dongzhihong 已提交
216
	parameter.ElementType = opt.elementType
217 218
	parameter.Content = opt.GetWeights()
	return nil
219 220
}

D
dongzhihong 已提交
221 222
// pserver save checkpoint
func (s *Service) doCheckpoint() error {
223
	<-s.initialized
D
dongzhihong 已提交
224 225
	s.mu.Lock()
	defer s.mu.Unlock()
D
dongzhihong 已提交
226

227
	cp := make([]parameterCheckpoint, len(s.optMap))
D
dongzhihong 已提交
228
	index := 0
D
dongzhihong 已提交
229
	for name, opt := range s.optMap {
230 231 232 233
		var pc parameterCheckpoint
		pc.Param.Name = name
		pc.Param.ElementType = opt.elementType
		pc.Param.Content = opt.GetWeights()
D
dongzhihong 已提交
234 235 236 237 238 239 240 241 242 243 244 245 246
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
	err := encoder.Encode(cp)
	if err != nil {
		return err
	}

	cpMeta := checkpointMeta{}
	cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
247
	cpMeta.Timestamp = time.Now().UnixNano()
D
dongzhihong 已提交
248
	h := md5.New()
249
	cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
D
dongzhihong 已提交
250

251
	cpMetajson, _ := json.Marshal(cpMeta)
252
	err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
D
dongzhihong 已提交
253 254 255 256 257 258 259
	if err != nil {
		return err
	}
	if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
		log.Info("checkpoint does not exists.")
	} else {
		err = os.Remove(cpMeta.UUID)
260 261 262 263 264
		if err != nil {
			log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
		} else {
			log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
		}
D
dongzhihong 已提交
265 266 267 268
	}
	f, err := os.Create(cpMeta.UUID)
	defer f.Close()
	if err != nil {
269
		return err
D
dongzhihong 已提交
270 271 272 273 274
	}
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	writer.Flush()
	if err != nil {
275
		return err
D
dongzhihong 已提交
276
	}
277 278
	return nil
}