service.go 9.5 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package pserver

import (
D
dongzhihong 已提交
18 19 20
	"bufio"
	"bytes"
	"encoding/gob"
D
dongzhihong 已提交
21
	"encoding/json"
22 23
	"errors"
	"fmt"
H
Helin Wang 已提交
24
	"hash/crc32"
25
	"io/ioutil"
D
dongzhihong 已提交
26
	"os"
H
Helin Wang 已提交
27
	"path"
D
dongzhihong 已提交
28
	"strconv"
29
	"sync"
D
dongzhihong 已提交
30 31
	"time"

H
Helin Wang 已提交
32 33
	uuid "github.com/satori/go.uuid"

34
	log "github.com/inconshreveable/log15"
35 36 37 38 39
)

// ElementType is the type of elements of a Parameter. It tags how the
// raw bytes in Parameter.Content are to be interpreted.
type ElementType int

40 41
// ErrCheckpointNotFound is the sentinel error reported when no
// checkpoint metadata is stored for a pserver, i.e. the key-value
// store returned an empty value for its checkpoint key.
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
43

44
// Error messages carried by the errors returned from the RPC methods.
const (
	// AlreadyInitialized is reported by InitParam and FinishInitParams
	// once parameter initialization has been marked as finished.
	AlreadyInitialized = "pserver already initialized"

	// Uninitialized is reported by SendGrad when a gradient arrives
	// before parameter initialization has finished.
	Uninitialized = "pserver not fully initialized"

	// WrongChecksum is reported by LoadCheckpoint when the CRC32 of the
	// checkpoint file does not match the value recorded in its metadata.
	WrongChecksum = "checkpoint file checksum validation failed"
)
50

51
// Supported element types.
52 53 54 55 56 57 58 59 60 61 62 63 64
// The values below are assigned sequentially by iota, starting at 0
// for Int32. They are stored in checkpoints and sent over RPC as part
// of Parameter, so the declaration order is presumably part of the
// wire/disk format -- verify before reordering or inserting entries.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
65
	Content     []byte
66 67 68 69 70 71 72 73
}

// ParameterWithConfig contains the parameter and the configuration.
// It is the argument of InitParam and is what an optimizer is built
// from.
type ParameterWithConfig struct {
	// Param is the parameter itself (name, element type, content).
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

74
// checkpointMeta describes a single checkpoint file. It is stored
// JSON-encoded in the key-value store so that the checkpoint file can
// be located and validated later.
type checkpointMeta struct {
	// UUID is the unique ID generated for the checkpoint; it is also
	// used as the checkpoint's file name.
	UUID string `json:"uuid"`
	// Path is the location of the checkpoint file on disk.
	Path string `json:"path"`
	// CRC32 is the IEEE CRC32 checksum of the checkpoint file content.
	CRC32 uint32 `json:"crc32"`
	// Timestamp is the creation time in nanoseconds since the Unix
	// epoch.
	Timestamp int64 `json:"timestamp"`
}

H
Helin Wang 已提交
82
// Checkpoint is the pserver shard persist in file.
83
type Checkpoint []parameterCheckpoint
D
dongzhihong 已提交
84

85
// Gradient is the gradient of the parameter.
D
dongzhihong 已提交
86
type Gradient Parameter
87

H
Helin Wang 已提交
88
// Service is the RPC service for pserver.
89
type Service struct {
D
dongzhihong 已提交
90 91
	initialized        chan struct{}
	idx                int
D
dongzhihong 已提交
92
	checkpointInterval time.Duration
D
dongzhihong 已提交
93
	checkpointPath     string
H
Helin Wang 已提交
94
	client             KVStore
H
Helin Wang 已提交
95 96 97

	mu     sync.Mutex
	optMap map[string]*optimizer
98 99
}

H
Helin Wang 已提交
100
// parameterCheckpoint saves parameter checkpoint.
101 102 103 104 105
type parameterCheckpoint struct {
	// ParameterWithConfig is embedded, so Param and Config are
	// promoted to fields of parameterCheckpoint.
	ParameterWithConfig
	// State is the optimizer state as returned by the optimizer's
	// GetStates; it is handed back to newOptimizer on recovery.
	State []byte
}

H
Helin Wang 已提交
106 107 108 109 110 111
// KVStore is the key-value store the pserver uses to persist and look
// up checkpoint metadata.
type KVStore interface {
	// GetKey returns the value stored under key. Callers in this file
	// treat an empty value with a nil error as "key not found".
	// timeout presumably bounds the duration of the request -- confirm
	// against the implementation.
	GetKey(key string, timeout time.Duration) ([]byte, error)
	// PutKey stores value under key. The meaning of withLease is
	// defined by the implementation; this file always passes false.
	PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}

func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
H
Helin Wang 已提交
112
	v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
113
	if err != nil {
H
Helin Wang 已提交
114
		return
115 116
	}

117
	if len(v) == 0 {
H
Helin Wang 已提交
118 119
		err = ErrCheckpointNotFound
		return
120 121
	}

H
Helin Wang 已提交
122 123
	if err = json.Unmarshal(v, &meta); err != nil {
		return
124 125
	}

H
Helin Wang 已提交
126 127 128 129
	return
}

// LoadCheckpoint loads checkpoint from file.
H
Helin Wang 已提交
130
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
H
Helin Wang 已提交
131 132 133
	log.Info("Loading checkpoint", "pserver index", idx)
	defer traceTime(time.Now(), "load checkpoint")

H
Helin Wang 已提交
134 135
	cpMeta, err := loadMeta(e, idx)
	if err != nil {
136 137
		return nil, err
	}
H
Helin Wang 已提交
138 139

	content, err := ioutil.ReadFile(cpMeta.Path)
140 141 142 143
	if err != nil {
		return nil, err
	}

H
Helin Wang 已提交
144 145
	crc32 := crc32.ChecksumIEEE(content)
	if crc32 != cpMeta.CRC32 {
H
Helin Wang 已提交
146
		return nil, errors.New(WrongChecksum)
147 148 149
	}

	dec := gob.NewDecoder(bytes.NewReader(content))
H
Helin Wang 已提交
150 151
	var cp Checkpoint
	if err = dec.Decode(&cp); err != nil {
152 153
		return nil, err
	}
H
Helin Wang 已提交
154

155 156 157
	return cp, nil
}

W
wuyi05 已提交
158
// NewService creates a new service, will bypass etcd registration if no
159
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
H
Helin Wang 已提交
160
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
161
	s := &Service{
D
dongzhihong 已提交
162
		idx:                idx,
163
		checkpointInterval: interval,
D
dongzhihong 已提交
164 165
		checkpointPath:     path,
		client:             client,
166
	}
D
dongzhihong 已提交
167
	s.optMap = make(map[string]*optimizer)
168
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
169 170

	if cp != nil {
D
dongzhihong 已提交
171
		for _, item := range cp {
172 173 174 175 176
			p := ParameterWithConfig{
				Param:  item.Param,
				Config: item.Config,
			}
			s.optMap[p.Param.Name] = newOptimizer(p, item.State)
D
dongzhihong 已提交
177
		}
H
Helin Wang 已提交
178
		close(s.initialized)
D
dongzhihong 已提交
179
	}
W
wuyi05 已提交
180
	return s, nil
181 182
}

H
Helin Wang 已提交
183
// InitParam initializes a parameter.
184
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
185 186
	select {
	case <-s.initialized:
H
Helin Wang 已提交
187
		log.Warn("init param called but parameters already initialized.")
188
		return errors.New(AlreadyInitialized)
189 190 191 192 193 194 195 196 197 198 199
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dongzhihong 已提交
200
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
H
Helin Wang 已提交
201 202 203 204 205 206 207
	log.Info(
		"init parameter",
		"name", paramWithConfigs.Param.Name,
		"config len", len(paramWithConfigs.Config),
		"param len", len(paramWithConfigs.Param.Content),
		"type", paramWithConfigs.Param.ElementType,
	)
208 209 210
	return nil
}

H
Helin Wang 已提交
211 212
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
213
func (s *Service) FinishInitParams(_ int, _ *int) error {
214 215
	select {
	case <-s.initialized:
H
Helin Wang 已提交
216
		log.Warn("finished init param called but parameters already initialized.")
217
		return errors.New(AlreadyInitialized)
218 219 220 221
	default:
	}

	close(s.initialized)
H
Helin Wang 已提交
222 223 224 225 226
	go func() {
		t := time.Tick(s.checkpointInterval)
		for range t {
			err := s.checkpoint()
			if err != nil {
H
Helin Wang 已提交
227
				log.Error("checkpoint error", log.Ctx{"error": err})
H
Helin Wang 已提交
228 229 230
			}
		}
	}()
H
Helin Wang 已提交
231 232

	log.Info("init parameter finished.")
233 234 235
	return nil
}

236
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
237
// optimization.
238
func (s *Service) SendGrad(g Gradient, _ *int) error {
239 240 241
	select {
	case <-s.initialized:
	default:
H
Helin Wang 已提交
242
		log.Warn("received gradient before initialization.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
243
		return errors.New(Uninitialized)
244
	}
245 246

	s.mu.Lock()
H
Helin Wang 已提交
247
	defer s.mu.Unlock()
248

D
dzhwinter 已提交
249 250
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
251
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
252
	}
253

H
Helin Wang 已提交
254
	log.Info("received gradient from trainer, updating gradient.", "name", g.Name, "size", len(g.Content), "type", g.ElementType)
D
dongzhihong 已提交
255
	return o.UpdateParameter(g)
256 257
}

258 259
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
260 261
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
262
	defer s.mu.Unlock()
263

D
dongzhihong 已提交
264
	opt, ok := s.optMap[name]
265
	if !ok {
H
Helin Wang 已提交
266
		log.Warn("trainer wants to get a parameter that does not exist.", "name", name)
267
		return fmt.Errorf("parameter: %s does not exist", name)
268 269
	}

270 271 272 273 274 275
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
276
	// parameter content.
D
dongzhihong 已提交
277
	parameter.Name = name
D
dongzhihong 已提交
278
	parameter.ElementType = opt.elementType
279
	parameter.Content = opt.GetWeights()
H
Helin Wang 已提交
280

H
Helin Wang 已提交
281
	log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
282
	return nil
283 284
}

H
Helin Wang 已提交
285 286
func traceTime(start time.Time, name string) {
	elapsed := time.Since(start)
287
	log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed})
H
Helin Wang 已提交
288 289 290 291 292 293 294
}

// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
295
	log.Info("Begin save checkpoint.")
H
Helin Wang 已提交
296
	defer traceTime(time.Now(), "save checkpoint")
D
dongzhihong 已提交
297

H
Helin Wang 已提交
298
	s.mu.Lock()
299
	cp := make([]parameterCheckpoint, len(s.optMap))
D
dongzhihong 已提交
300
	index := 0
H
Helin Wang 已提交
301 302
	// TODO(helin): write checkpoint incrementally to reduce memory
	// footprint during checkpoint.
D
dongzhihong 已提交
303
	for name, opt := range s.optMap {
304 305 306 307
		var pc parameterCheckpoint
		pc.Param.Name = name
		pc.Param.ElementType = opt.elementType
		pc.Param.Content = opt.GetWeights()
H
Helin Wang 已提交
308
		pc.Config = opt.config
D
dongzhihong 已提交
309 310 311 312
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
H
Helin Wang 已提交
313 314
	s.mu.Unlock()

D
dongzhihong 已提交
315 316
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
H
Helin Wang 已提交
317
	err = encoder.Encode(cp)
D
dongzhihong 已提交
318
	if err != nil {
H
Helin Wang 已提交
319
		return
D
dongzhihong 已提交
320 321
	}

322 323 324 325 326 327 328
	if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) {
		err = os.MkdirAll(s.checkpointPath, os.ModePerm)
		if err != nil {
			return
		}
	}

H
Helin Wang 已提交
329 330 331
	id := uuid.NewV4().String()
	p := path.Join(s.checkpointPath, id)
	f, err := os.Create(p)
D
dongzhihong 已提交
332
	if err != nil {
H
Helin Wang 已提交
333
		return
D
dongzhihong 已提交
334
	}
H
Helin Wang 已提交
335 336 337 338 339

	defer func() {
		closeErr := f.Close()
		if closeErr != nil {
			if err != nil {
340
				log.Error("error close checkpoint file", log.Ctx{"error": closeErr})
H
Helin Wang 已提交
341 342 343 344 345 346 347
			} else {
				// Set closeErr as return value.
				err = closeErr
			}
		}
	}()

D
dongzhihong 已提交
348 349 350
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	if err != nil {
H
Helin Wang 已提交
351
		return
D
dongzhihong 已提交
352
	}
H
Helin Wang 已提交
353 354 355 356 357 358

	err = writer.Flush()
	if err != nil {
		return
	}

H
Helin Wang 已提交
359 360
	oldMeta, err := loadMeta(s.client, s.idx)
	if err == ErrCheckpointNotFound {
H
Helin Wang 已提交
361
		log.Info("old meta not found, skip removing old meta")
H
Helin Wang 已提交
362
		err = nil
H
Helin Wang 已提交
363 364 365 366 367 368 369 370 371 372
	} else if err == nil {
		log.Info("removing old meta")
		if oldMeta.Path != "" {
			rmErr := os.Remove(oldMeta.Path)
			if rmErr != nil {
				// log error, but still treat checkpoint as
				// successful.
				log.Error("remove old meta file error", log.Ctx{"error": rmErr})
			}
		}
H
Helin Wang 已提交
373 374 375 376 377 378
	}

	if err != nil {
		return
	}

H
Helin Wang 已提交
379
	crc32 := crc32.ChecksumIEEE(buf.Bytes())
H
Helin Wang 已提交
380 381 382
	cpMeta := checkpointMeta{
		UUID:      id,
		Timestamp: time.Now().UnixNano(),
H
Helin Wang 已提交
383
		CRC32:     crc32,
H
Helin Wang 已提交
384 385 386 387 388 389 390 391 392 393 394 395 396
		Path:      p,
	}

	json, err := json.Marshal(cpMeta)
	if err != nil {
		return
	}

	err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false)
	if err != nil {
		return
	}

H
Helin Wang 已提交
397
	return
398
}