service.go 8.3 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package pserver

import (
D
dongzhihong 已提交
18 19 20 21
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
D
dongzhihong 已提交
22
	"encoding/hex"
D
dongzhihong 已提交
23
	"encoding/json"
24 25
	"errors"
	"fmt"
26
	"io/ioutil"
D
dongzhihong 已提交
27
	"os"
H
Helin Wang 已提交
28
	"path"
D
dongzhihong 已提交
29
	"strconv"
30
	"sync"
D
dongzhihong 已提交
31 32
	"time"

H
Helin Wang 已提交
33 34
	uuid "github.com/satori/go.uuid"

D
dongzhihong 已提交
35
	log "github.com/sirupsen/logrus"
36 37 38 39 40
)

// ElementType identifies the type of the elements stored in a
// Parameter's Content.
type ElementType int

// Supported element types. Values are assigned by iota in declaration
// order; because ElementType is stored in gob-encoded checkpoints, new
// types must only be appended at the end to keep existing values stable.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")

// Messages of the errors returned by the pserver service.
const (
	AlreadyInitialized = "pserver already initialized"
	Uninitialized      = "pserver not fully initialized"
	WrongChecksum      = "checkpoint file checksum validation failed"
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
66
	Content     []byte
67 68 69 70 71 72 73 74
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

75
// checkpointMeta is the checkpoint metadata that is JSON-encoded and
// stored in etcd. It points to the checkpoint file and carries the
// checksum used to validate it.
type checkpointMeta struct {
	// UUID identifies the checkpoint; it is also the file name.
	UUID string `json:"uuid"`

	// Path is the location of the checkpoint file.
	Path string `json:"path"`

	// MD5 is the checksum of the checkpoint file content.
	MD5 string `json:"md5"`

	// Timestamp is the creation time in Unix nanoseconds.
	Timestamp int64 `json:"timestamp"`
}

H
Helin Wang 已提交
83
// Checkpoint is the pserver shard persist in file.
84
type Checkpoint []parameterCheckpoint
D
dongzhihong 已提交
85

86
// Gradient is the gradient of the parameter.
D
dongzhihong 已提交
87
type Gradient Parameter
88

H
Helin Wang 已提交
89
// Service is the RPC service for pserver.
90
type Service struct {
D
dongzhihong 已提交
91 92
	initialized        chan struct{}
	idx                int
D
dongzhihong 已提交
93
	checkpointInterval time.Duration
D
dongzhihong 已提交
94 95
	checkpointPath     string
	client             *EtcdClient
H
Helin Wang 已提交
96 97 98

	mu     sync.Mutex
	optMap map[string]*optimizer
99 100
}

H
Helin Wang 已提交
101
// parameterCheckpoint saves parameter checkpoint.
102 103 104 105 106
type parameterCheckpoint struct {
	ParameterWithConfig
	State []byte
}

H
Helin Wang 已提交
107 108
func loadMeta(e *EtcdClient, idx int) (meta checkpointMeta, err error) {
	v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
109
	if err != nil {
H
Helin Wang 已提交
110
		return
111 112
	}

113
	if len(v) == 0 {
H
Helin Wang 已提交
114 115
		err = ErrCheckpointNotFound
		return
116 117
	}

H
Helin Wang 已提交
118 119
	if err = json.Unmarshal(v, &meta); err != nil {
		return
120 121
	}

H
Helin Wang 已提交
122 123 124 125 126 127 128
	return
}

// LoadCheckpoint loads checkpoint from file.
func LoadCheckpoint(e *EtcdClient, idx int) (Checkpoint, error) {
	cpMeta, err := loadMeta(e, idx)
	if err != nil {
129 130
		return nil, err
	}
H
Helin Wang 已提交
131 132

	content, err := ioutil.ReadFile(cpMeta.Path)
133 134 135 136
	if err != nil {
		return nil, err
	}

H
Helin Wang 已提交
137 138
	// TODO(helin): change MD5 to CRC since CRC is better for file
	// checksum in our use case (emphasize speed over security).
139 140 141
	h := md5.New()
	md5 := hex.EncodeToString(h.Sum(content))
	if md5 != cpMeta.MD5 {
H
Helin Wang 已提交
142
		return nil, errors.New(WrongChecksum)
143 144 145
	}

	dec := gob.NewDecoder(bytes.NewReader(content))
H
Helin Wang 已提交
146 147
	var cp Checkpoint
	if err = dec.Decode(&cp); err != nil {
148 149 150 151 152
		return nil, err
	}
	return cp, nil
}

W
wuyi05 已提交
153
// NewService creates a new service, will bypass etcd registration if no
154
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
D
dongzhihong 已提交
155
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
156
	s := &Service{
D
dongzhihong 已提交
157
		idx:                idx,
158
		checkpointInterval: interval,
D
dongzhihong 已提交
159 160
		checkpointPath:     path,
		client:             client,
161
	}
D
dongzhihong 已提交
162
	s.optMap = make(map[string]*optimizer)
163
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
164 165

	if cp != nil {
D
dongzhihong 已提交
166
		for _, item := range cp {
167 168 169 170 171
			p := ParameterWithConfig{
				Param:  item.Param,
				Config: item.Config,
			}
			s.optMap[p.Param.Name] = newOptimizer(p, item.State)
D
dongzhihong 已提交
172 173
		}
	}
W
wuyi05 已提交
174
	return s, nil
175 176
}

H
Helin Wang 已提交
177
// InitParam initializes a parameter.
178
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
179 180
	select {
	case <-s.initialized:
181
		return errors.New(AlreadyInitialized)
182 183 184 185 186 187 188 189 190 191 192
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dongzhihong 已提交
193
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
194 195 196
	return nil
}

H
Helin Wang 已提交
197 198
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
199
func (s *Service) FinishInitParams(_ int, _ *int) error {
200 201
	select {
	case <-s.initialized:
202
		return errors.New(AlreadyInitialized)
203 204 205 206
	default:
	}

	close(s.initialized)
H
Helin Wang 已提交
207 208 209 210 211 212 213 214 215
	go func() {
		t := time.Tick(s.checkpointInterval)
		for range t {
			err := s.checkpoint()
			if err != nil {
				log.Errorln(err)
			}
		}
	}()
216 217 218
	return nil
}

219
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
220
// optimization.
221
func (s *Service) SendGrad(g Gradient, _ *int) error {
222 223 224
	select {
	case <-s.initialized:
	default:
225
		return errors.New(Uninitialized)
226
	}
227 228

	s.mu.Lock()
H
Helin Wang 已提交
229
	defer s.mu.Unlock()
230

D
dzhwinter 已提交
231 232
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
233
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
234
	}
235

D
dongzhihong 已提交
236
	return o.UpdateParameter(g)
237 238
}

239 240
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
241 242
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
243
	defer s.mu.Unlock()
244

D
dongzhihong 已提交
245
	opt, ok := s.optMap[name]
246 247
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
248 249
	}

250 251 252 253 254 255
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
256
	// parameter content.
D
dongzhihong 已提交
257
	parameter.Name = name
D
dongzhihong 已提交
258
	parameter.ElementType = opt.elementType
259 260
	parameter.Content = opt.GetWeights()
	return nil
261 262
}

H
Helin Wang 已提交
263 264 265 266 267 268 269 270 271 272 273 274
// traceTime logs how long the operation called name has taken since
// start. Intended to be deferred: defer traceTime(time.Now(), "op").
func traceTime(start time.Time, name string) {
	log.Infof("%s took %v", name, time.Since(start))
}

// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
	log.Infoln("Begin save checkpoint.")
	defer traceTime(time.Now(), "save checkpoint")
D
dongzhihong 已提交
275

H
Helin Wang 已提交
276
	s.mu.Lock()
277
	cp := make([]parameterCheckpoint, len(s.optMap))
D
dongzhihong 已提交
278
	index := 0
H
Helin Wang 已提交
279 280
	// TODO(helin): write checkpoint incrementally to reduce memory
	// footprint during checkpoint.
D
dongzhihong 已提交
281
	for name, opt := range s.optMap {
282 283 284 285
		var pc parameterCheckpoint
		pc.Param.Name = name
		pc.Param.ElementType = opt.elementType
		pc.Param.Content = opt.GetWeights()
H
Helin Wang 已提交
286
		pc.Config = opt.config
D
dongzhihong 已提交
287 288 289 290
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
H
Helin Wang 已提交
291 292
	s.mu.Unlock()

D
dongzhihong 已提交
293 294
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
H
Helin Wang 已提交
295
	err = encoder.Encode(cp)
D
dongzhihong 已提交
296
	if err != nil {
H
Helin Wang 已提交
297
		return
D
dongzhihong 已提交
298 299
	}

300 301 302 303 304 305 306
	if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) {
		err = os.MkdirAll(s.checkpointPath, os.ModePerm)
		if err != nil {
			return
		}
	}

H
Helin Wang 已提交
307 308 309
	id := uuid.NewV4().String()
	p := path.Join(s.checkpointPath, id)
	f, err := os.Create(p)
D
dongzhihong 已提交
310
	if err != nil {
H
Helin Wang 已提交
311
		return
D
dongzhihong 已提交
312
	}
H
Helin Wang 已提交
313 314 315 316 317 318 319 320 321 322 323 324 325

	defer func() {
		closeErr := f.Close()
		if closeErr != nil {
			if err != nil {
				log.Errorln(closeErr)
			} else {
				// Set closeErr as return value.
				err = closeErr
			}
		}
	}()

D
dongzhihong 已提交
326 327 328
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	if err != nil {
H
Helin Wang 已提交
329
		return
D
dongzhihong 已提交
330
	}
H
Helin Wang 已提交
331 332 333 334 335 336

	err = writer.Flush()
	if err != nil {
		return
	}

H
Helin Wang 已提交
337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374
	oldMeta, err := loadMeta(s.client, s.idx)
	if err == ErrCheckpointNotFound {
		log.Infoln("Do not have existing checkpoint.")
		err = nil
	}

	if err != nil {
		return
	}

	h := md5.New()
	md5 := hex.EncodeToString(h.Sum(buf.Bytes()))
	cpMeta := checkpointMeta{
		UUID:      id,
		Timestamp: time.Now().UnixNano(),
		MD5:       md5,
		Path:      p,
	}

	json, err := json.Marshal(cpMeta)
	if err != nil {
		return
	}

	err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false)
	if err != nil {
		return
	}

	if oldMeta.Path != "" {
		rmErr := os.Remove(oldMeta.Path)
		if rmErr != nil {
			// log error, but still treat checkpoint as
			// successful.
			log.Errorln(rmErr)
		}
	}

H
Helin Wang 已提交
375
	return
376
}