service.go 10.6 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package pserver

import (
D
dongzhihong 已提交
18 19
	"bufio"
	"bytes"
20
	"encoding/binary"
D
dongzhihong 已提交
21
	"encoding/gob"
D
dongzhihong 已提交
22
	"encoding/json"
23 24
	"errors"
	"fmt"
H
Helin Wang 已提交
25
	"hash/crc32"
26
	"io/ioutil"
D
dongzhihong 已提交
27
	"os"
H
Helin Wang 已提交
28
	"path"
D
dongzhihong 已提交
29
	"strconv"
30
	"strings"
31
	"sync"
D
dongzhihong 已提交
32 33
	"time"

34
	"github.com/golang/protobuf/proto"
H
Helin Wang 已提交
35 36
	uuid "github.com/satori/go.uuid"

37 38
	pb "github.com/PaddlePaddle/Paddle/go/proto"

39
	log "github.com/inconshreveable/log15"
40 41 42 43 44
)

// ElementType is the type of elements of a Parameter.
//
// It is an int-based enumeration: a Parameter carries one of these values in
// its ElementType field next to the raw Content bytes.
type ElementType int

45 46
// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found. loadMeta returns it when the key-value store yields an
// empty value for the checkpoint key.
var ErrCheckpointNotFound = errors.New("checkpoint not found in etcd")
48

49
// RPC error messages returned to callers of the Service methods.
const (
	// AlreadyInitialized is reported when initialization is attempted
	// after the parameters have been initialized.
	AlreadyInitialized = "pserver already initialized"
	// Uninitialized is reported when a gradient arrives before the
	// parameters have been initialized.
	Uninitialized = "pserver not fully initialized"
	// WrongChecksum is reported when the CRC32 of a checkpoint file does
	// not match the value recorded in its metadata.
	WrongChecksum = "checkpoint file checksum validation failed"
)
55

56
// Supported element types.
57 58 59 60 61 62 63 64 65 66 67 68 69
// Element type values are assigned by iota in declaration order, starting at
// 0 for Int32 and ending at 5 for Float64. In this file only Float32 receives
// special treatment (see Parameter.String).
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
70
	Content     []byte
71 72
}

73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112
// float32ToString decodes b as a sequence of little-endian float32 values
// and formats them the way fmt prints a []float32 (for example "[1 2]").
// Trailing bytes that do not fill a whole float32 are ignored. An empty
// string is returned if decoding fails.
func float32ToString(b []byte) string {
	values := make([]float32, len(b)/4)
	if err := binary.Read(bytes.NewReader(b), binary.LittleEndian, &values); err != nil {
		return ""
	}
	return fmt.Sprint(values)
}

// float32ByteToString renders float32 content for logging. Content of up to
// 80 bytes (20 values) is printed in full; longer content is abbreviated to
// its first and last 40 bytes (10 values each) joined by "...".
func float32ByteToString(c []byte) string {
	const edge = 40 // bytes kept at each end of long content

	if len(c) <= 2*edge {
		return float32ToString(c)
	}

	head := float32ToString(c[:edge])
	tail := float32ToString(c[len(c)-edge:])

	// Drop the closing bracket of the head and the opening bracket of the
	// tail so that the result reads as one list.
	return strings.Replace(head, "]", "", -1) + "..." + strings.Replace(tail, "[", "", -1)
}

// String formats the parameter for logging. Float32 parameters are rendered
// as their (possibly abbreviated) values; every other element type is
// rendered as its name and element type only.
func (p Parameter) String() string {
	if p.ElementType == Float32 {
		return float32ByteToString(p.Content)
	}

	return fmt.Sprintf("name:%v ElementType:%v", p.Name, p.ElementType)
}

113 114 115 116 117 118
// ParameterWithConfig contains the parameter and the configuration.
//
// It is the argument of Service.InitParam and the input from which an
// optimizer is created for the parameter.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

119
// checkpointMeta is the checkpoint metadata that is stored, JSON encoded,
// in the key-value store.
type checkpointMeta struct {
	// UUID is the identifier generated for the checkpoint; it is also the
	// file name of the checkpoint inside the checkpoint directory.
	UUID string `json:"uuid"`
	// Path is the location of the checkpoint file.
	Path string `json:"path"`
	// CRC32 is the IEEE CRC32 checksum of the checkpoint file content.
	CRC32 uint32 `json:"crc32"`
	// Timestamp is the creation time in nanoseconds since the Unix epoch.
	Timestamp int64 `json:"timestamp"`
}

H
Helin Wang 已提交
127
// Checkpoint is the pserver shard persist in file.
128
type Checkpoint []parameterCheckpoint
D
dongzhihong 已提交
129

130
// Gradient is the gradient of the parameter.
D
dongzhihong 已提交
131
type Gradient Parameter
132

H
Helin Wang 已提交
133
// Service is the RPC service for pserver.
134
type Service struct {
D
dongzhihong 已提交
135 136
	initialized        chan struct{}
	idx                int
D
dongzhihong 已提交
137
	checkpointInterval time.Duration
D
dongzhihong 已提交
138
	checkpointPath     string
H
Helin Wang 已提交
139
	client             KVStore
H
Helin Wang 已提交
140 141 142

	mu     sync.Mutex
	optMap map[string]*optimizer
143 144
}

H
Helin Wang 已提交
145
// parameterCheckpoint saves parameter checkpoint.
146 147 148 149 150
type parameterCheckpoint struct {
	// ParameterWithConfig is embedded, so Param and Config are accessible
	// directly on a parameterCheckpoint.
	ParameterWithConfig
	// State is the optimizer state as returned by the optimizer's
	// GetStates method; it is handed back to newOptimizer on recovery.
	State []byte
}

H
Helin Wang 已提交
151 152 153 154 155 156
// KVStore is the key-value store the service uses to persist and look up
// checkpoint metadata.
type KVStore interface {
	// GetKey returns the value stored under key. loadMeta treats an empty
	// value returned without an error as "checkpoint not found".
	GetKey(key string, timeout time.Duration) ([]byte, error)
	// PutKey stores value under key.
	// NOTE(review): withLease presumably binds the key to a lease of the
	// underlying store; confirm against the implementation.
	PutKey(key string, value []byte, timeout time.Duration, withLease bool) error
}

func loadMeta(e KVStore, idx int) (meta checkpointMeta, err error) {
H
Helin Wang 已提交
157
	v, err := e.GetKey(PsCheckpoint+strconv.Itoa(idx), 3*time.Second)
158
	if err != nil {
H
Helin Wang 已提交
159
		return
160 161
	}

162
	if len(v) == 0 {
H
Helin Wang 已提交
163 164
		err = ErrCheckpointNotFound
		return
165 166
	}

H
Helin Wang 已提交
167 168
	if err = json.Unmarshal(v, &meta); err != nil {
		return
169 170
	}

H
Helin Wang 已提交
171 172 173 174
	return
}

// LoadCheckpoint loads checkpoint from file.
H
Helin Wang 已提交
175
func LoadCheckpoint(e KVStore, idx int) (Checkpoint, error) {
H
Helin Wang 已提交
176 177 178
	log.Info("Loading checkpoint", "pserver index", idx)
	defer traceTime(time.Now(), "load checkpoint")

H
Helin Wang 已提交
179 180
	cpMeta, err := loadMeta(e, idx)
	if err != nil {
181 182
		return nil, err
	}
H
Helin Wang 已提交
183 184

	content, err := ioutil.ReadFile(cpMeta.Path)
185 186 187 188
	if err != nil {
		return nil, err
	}

H
Helin Wang 已提交
189 190
	crc32 := crc32.ChecksumIEEE(content)
	if crc32 != cpMeta.CRC32 {
H
Helin Wang 已提交
191
		return nil, errors.New(WrongChecksum)
192 193 194
	}

	dec := gob.NewDecoder(bytes.NewReader(content))
H
Helin Wang 已提交
195 196
	var cp Checkpoint
	if err = dec.Decode(&cp); err != nil {
197 198
		return nil, err
	}
H
Helin Wang 已提交
199

200 201 202
	return cp, nil
}

W
wuyi05 已提交
203
// NewService creates a new service, will bypass etcd registration if no
204
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
H
Helin Wang 已提交
205
func NewService(idx int, interval time.Duration, path string, client KVStore, cp Checkpoint) (*Service, error) {
206
	s := &Service{
D
dongzhihong 已提交
207
		idx:                idx,
208
		checkpointInterval: interval,
D
dongzhihong 已提交
209 210
		checkpointPath:     path,
		client:             client,
211
	}
D
dongzhihong 已提交
212
	s.optMap = make(map[string]*optimizer)
213
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
214 215

	if cp != nil {
D
dongzhihong 已提交
216
		for _, item := range cp {
217 218 219 220 221
			p := ParameterWithConfig{
				Param:  item.Param,
				Config: item.Config,
			}
			s.optMap[p.Param.Name] = newOptimizer(p, item.State)
D
dongzhihong 已提交
222
		}
H
Helin Wang 已提交
223
		close(s.initialized)
D
dongzhihong 已提交
224
	}
W
wuyi05 已提交
225
	return s, nil
226 227
}

H
Helin Wang 已提交
228
// InitParam initializes a parameter.
229
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, _ *int) error {
230 231
	select {
	case <-s.initialized:
H
Helin Wang 已提交
232
		log.Warn("init param called but parameters already initialized.")
233
		return errors.New(AlreadyInitialized)
234 235 236
	default:
	}

237 238 239
	c := &pb.OptimizerConfig{}
	proto.Unmarshal(paramWithConfigs.Config, c)
	log.Debug(fmt.Sprintf("OptimizerConfig:%v", c))
240 241 242 243 244 245 246

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dongzhihong 已提交
247
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
H
Helin Wang 已提交
248 249 250 251 252 253 254
	log.Info(
		"init parameter",
		"name", paramWithConfigs.Param.Name,
		"config len", len(paramWithConfigs.Config),
		"param len", len(paramWithConfigs.Param.Content),
		"type", paramWithConfigs.Param.ElementType,
	)
255 256 257
	return nil
}

H
Helin Wang 已提交
258 259
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
260
func (s *Service) FinishInitParams(_ int, _ *int) error {
261 262
	select {
	case <-s.initialized:
H
Helin Wang 已提交
263
		log.Warn("finished init param called but parameters already initialized.")
264
		return errors.New(AlreadyInitialized)
265 266 267 268
	default:
	}

	close(s.initialized)
H
Helin Wang 已提交
269 270 271 272 273
	go func() {
		t := time.Tick(s.checkpointInterval)
		for range t {
			err := s.checkpoint()
			if err != nil {
H
Helin Wang 已提交
274
				log.Error("checkpoint error", log.Ctx{"error": err})
H
Helin Wang 已提交
275 276 277
			}
		}
	}()
H
Helin Wang 已提交
278 279

	log.Info("init parameter finished.")
280 281 282
	return nil
}

283
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
284
// optimization.
285
func (s *Service) SendGrad(g Gradient, _ *int) error {
286 287 288
	select {
	case <-s.initialized:
	default:
289 290
		log.Warn("received gradient before initialization.",
			"name", g.Name, "size", len(g.Content), "type", g.ElementType)
291
		return errors.New(Uninitialized)
292
	}
293 294

	s.mu.Lock()
H
Helin Wang 已提交
295
	defer s.mu.Unlock()
296

D
dzhwinter 已提交
297 298
	o, ok := s.optMap[g.Name]
	if !ok {
299 300
		log.Warn("received gradient but can't find name.",
			"name", g.Name, "size", len(g.Content), "type", g.ElementType)
D
dzhwinter 已提交
301
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
302
	}
303

304 305 306
	log.Debug(Parameter(g).String())
	log.Info("received gradient from trainer, updating gradient.",
		"name", g.Name, "size", len(g.Content), "type", g.ElementType)
D
dongzhihong 已提交
307
	return o.UpdateParameter(g)
308 309
}

310 311
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
312 313
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
314
	defer s.mu.Unlock()
315

D
dongzhihong 已提交
316
	opt, ok := s.optMap[name]
317
	if !ok {
H
Helin Wang 已提交
318
		log.Warn("trainer wants to get a parameter that does not exist.", "name", name)
319
		return fmt.Errorf("parameter: %s does not exist", name)
320 321
	}

322 323 324 325 326 327
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
328
	// parameter content.
D
dongzhihong 已提交
329
	parameter.Name = name
D
dongzhihong 已提交
330
	parameter.ElementType = opt.elementType
331
	parameter.Content = opt.GetWeights()
332
	log.Debug(parameter.String())
H
Helin Wang 已提交
333
	log.Info("sending parameter to the trainer", "name", parameter.Name, "size", len(parameter.Content), "type", parameter.ElementType)
334
	return nil
335 336
}

H
Helin Wang 已提交
337 338
func traceTime(start time.Time, name string) {
	elapsed := time.Since(start)
339
	log.Info("time elapsed", log.Ctx{"name": name, "elapsed": elapsed})
H
Helin Wang 已提交
340 341 342 343 344 345 346
}

// checkpoint saves checkpoint to disk.
//
// checkpoint should be only called after the parameters are
// initialized.
func (s *Service) checkpoint() (err error) {
347
	log.Info("Begin save checkpoint.")
H
Helin Wang 已提交
348
	defer traceTime(time.Now(), "save checkpoint")
D
dongzhihong 已提交
349

H
Helin Wang 已提交
350
	s.mu.Lock()
351
	cp := make([]parameterCheckpoint, len(s.optMap))
D
dongzhihong 已提交
352
	index := 0
H
Helin Wang 已提交
353 354
	// TODO(helin): write checkpoint incrementally to reduce memory
	// footprint during checkpoint.
D
dongzhihong 已提交
355
	for name, opt := range s.optMap {
356 357 358 359
		var pc parameterCheckpoint
		pc.Param.Name = name
		pc.Param.ElementType = opt.elementType
		pc.Param.Content = opt.GetWeights()
H
Helin Wang 已提交
360
		pc.Config = opt.config
D
dongzhihong 已提交
361 362 363 364
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
H
Helin Wang 已提交
365 366
	s.mu.Unlock()

D
dongzhihong 已提交
367 368
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
H
Helin Wang 已提交
369
	err = encoder.Encode(cp)
D
dongzhihong 已提交
370
	if err != nil {
H
Helin Wang 已提交
371
		return
D
dongzhihong 已提交
372 373
	}

374 375 376 377 378 379 380
	if _, err = os.Stat(s.checkpointPath); os.IsNotExist(err) {
		err = os.MkdirAll(s.checkpointPath, os.ModePerm)
		if err != nil {
			return
		}
	}

H
Helin Wang 已提交
381 382 383
	id := uuid.NewV4().String()
	p := path.Join(s.checkpointPath, id)
	f, err := os.Create(p)
D
dongzhihong 已提交
384
	if err != nil {
H
Helin Wang 已提交
385
		return
D
dongzhihong 已提交
386
	}
H
Helin Wang 已提交
387 388 389 390 391

	defer func() {
		closeErr := f.Close()
		if closeErr != nil {
			if err != nil {
392
				log.Error("error close checkpoint file", log.Ctx{"error": closeErr})
H
Helin Wang 已提交
393 394 395 396 397 398 399
			} else {
				// Set closeErr as return value.
				err = closeErr
			}
		}
	}()

D
dongzhihong 已提交
400 401 402
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	if err != nil {
H
Helin Wang 已提交
403
		return
D
dongzhihong 已提交
404
	}
H
Helin Wang 已提交
405 406 407 408 409 410

	err = writer.Flush()
	if err != nil {
		return
	}

H
Helin Wang 已提交
411 412
	oldMeta, err := loadMeta(s.client, s.idx)
	if err == ErrCheckpointNotFound {
H
Helin Wang 已提交
413
		log.Info("old meta not found, skip removing old meta")
H
Helin Wang 已提交
414
		err = nil
H
Helin Wang 已提交
415 416 417 418 419 420 421 422 423 424
	} else if err == nil {
		log.Info("removing old meta")
		if oldMeta.Path != "" {
			rmErr := os.Remove(oldMeta.Path)
			if rmErr != nil {
				// log error, but still treat checkpoint as
				// successful.
				log.Error("remove old meta file error", log.Ctx{"error": rmErr})
			}
		}
H
Helin Wang 已提交
425 426 427 428 429 430
	}

	if err != nil {
		return
	}

H
Helin Wang 已提交
431
	crc32 := crc32.ChecksumIEEE(buf.Bytes())
H
Helin Wang 已提交
432 433 434
	cpMeta := checkpointMeta{
		UUID:      id,
		Timestamp: time.Now().UnixNano(),
H
Helin Wang 已提交
435
		CRC32:     crc32,
H
Helin Wang 已提交
436 437 438 439 440 441 442 443 444 445 446 447 448
		Path:      p,
	}

	json, err := json.Marshal(cpMeta)
	if err != nil {
		return
	}

	err = s.client.PutKey(PsCheckpoint+strconv.Itoa(s.idx), json, 3*time.Second, false)
	if err != nil {
		return
	}

H
Helin Wang 已提交
449
	return
450
}