service.go 7.2 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package pserver

import (
D
dongzhihong 已提交
18 19 20 21
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
D
dongzhihong 已提交
22
	"encoding/hex"
D
dongzhihong 已提交
23
	"encoding/json"
24 25
	"errors"
	"fmt"
26
	"io/ioutil"
D
dongzhihong 已提交
27
	"os"
D
dongzhihong 已提交
28
	"path/filepath"
D
dongzhihong 已提交
29
	"strconv"
30
	"sync"
D
dongzhihong 已提交
31 32 33
	"time"

	log "github.com/sirupsen/logrus"
34 35 36 37 38
)

// ElementType is the type of elements of a Parameter.
type ElementType int

// Supported element types. The numeric values are written out explicitly
// because they travel inside Parameter values (gob-encoded into checkpoints
// and returned by GetParam), so they must never be renumbered.
const (
	Int32   ElementType = 0
	UInt32  ElementType = 1
	Int64   ElementType = 2
	UInt64  ElementType = 3
	Float32 ElementType = 4
	Float64 ElementType = 5
)

// Error messages used by the pserver RPC methods (wrapped with errors.New).
const (
	// AlreadyInitialized reports an initialization call made after parameter
	// initialization has already finished.
	AlreadyInitialized  = "pserver already initialized"
	// Uninitialized reports a request that needs the parameters to be
	// initialized before initialization has finished.
	Uninitialized       = "pserver not fully initialized"
	// CheckpointMD5Failed reports a checkpoint file whose checksum does not
	// match the one recorded in its metadata.
	CheckpointMD5Failed = "checkpoint file MD5 validation failed"
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
60
	Content     []byte
61 62 63 64 65 66 67 68
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

// checkpointMeta is the checkpoint metadata stored, JSON-encoded, in etcd for
// each pserver shard.
type checkpointMeta struct {
	// UUID identifies the checkpoint file; the loader resolves it relative to
	// its checkpoint directory.
	UUID      string `json:"uuid"`
	// MD5 is the hex-encoded checksum the loader compares against the content
	// of the checkpoint file.
	MD5       string `json:"md5"`
	// Timestamp is the time.Now().UnixNano() value taken when the checkpoint
	// was written.
	Timestamp int64  `json:"timestamp"`
}

// Checkpoint is the pserver shard persist in file
77
type Checkpoint []parameterCheckpoint
D
dongzhihong 已提交
78

79
// Gradient is the gradient of the parameter.
D
dongzhihong 已提交
80
type Gradient Parameter
81

H
Helin Wang 已提交
82
// Service is the RPC service for pserver.
83
type Service struct {
D
dongzhihong 已提交
84 85
	initialized        chan struct{}
	idx                int
D
dongzhihong 已提交
86
	checkpointInterval time.Duration
D
dongzhihong 已提交
87 88 89 90
	checkpointPath     string
	client             *EtcdClient
	mu                 sync.Mutex
	optMap             map[string]*optimizer
91 92
}

93 94 95 96 97 98 99
// parameterCheckpoint saves parameter checkpoint
type parameterCheckpoint struct {
	ParameterWithConfig
	State []byte
}

// NewCheckpointFromFile loads parameters and state from checkpoint file
D
dongzhihong 已提交
100
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
	v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
	if err != nil {
		return nil, err
	}

	var cpMeta checkpointMeta
	if err = json.Unmarshal(v, &cpMeta); err != nil {
		return nil, err
	}

	fn := filepath.Join(cpPath, cpMeta.UUID)
	if _, err = os.Stat(fn); os.IsNotExist(err) {
		return nil, err
	}
	content, err := ioutil.ReadFile(fn)
	if err != nil {
		return nil, err
	}

	h := md5.New()
	md5 := hex.EncodeToString(h.Sum(content))
	if md5 != cpMeta.MD5 {
		return nil, errors.New(CheckpointMD5Failed)
	}

	dec := gob.NewDecoder(bytes.NewReader(content))
D
dongzhihong 已提交
127
	cp := Checkpoint{}
128 129 130 131 132 133
	if err = dec.Decode(cp); err != nil {
		return nil, err
	}
	return cp, nil
}

W
wuyi05 已提交
134
// NewService creates a new service, will bypass etcd registration if no
135
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
D
dongzhihong 已提交
136
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
137
	s := &Service{
D
dongzhihong 已提交
138
		idx:                idx,
139
		checkpointInterval: interval,
D
dongzhihong 已提交
140 141
		checkpointPath:     path,
		client:             client,
142
	}
D
dongzhihong 已提交
143
	s.optMap = make(map[string]*optimizer)
144
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
145 146

	if cp != nil {
D
dongzhihong 已提交
147
		for _, item := range cp {
148 149 150 151 152
			p := ParameterWithConfig{
				Param:  item.Param,
				Config: item.Config,
			}
			s.optMap[p.Param.Name] = newOptimizer(p, item.State)
D
dongzhihong 已提交
153 154
		}
	}
W
wuyi05 已提交
155
	return s, nil
156 157
}

H
Helin Wang 已提交
158
// InitParam initializes a parameter.
159 160 161
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
162
		return errors.New(AlreadyInitialized)
163 164 165 166 167 168 169 170 171 172 173
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dongzhihong 已提交
174
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
175 176 177
	return nil
}

H
Helin Wang 已提交
178 179
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
180 181 182
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
183
		return errors.New(AlreadyInitialized)
184 185 186 187 188 189 190
	default:
	}

	close(s.initialized)
	return nil
}

191
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
192
// optimization.
193
func (s *Service) SendGrad(g Gradient, dummy *int) error {
194 195 196
	select {
	case <-s.initialized:
	default:
197
		return errors.New(Uninitialized)
198
	}
199 200

	s.mu.Lock()
H
Helin Wang 已提交
201
	defer s.mu.Unlock()
202

D
dzhwinter 已提交
203 204
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
205
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
206
	}
207

D
dongzhihong 已提交
208
	return o.UpdateParameter(g)
209 210
}

211 212
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
213 214
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
215
	defer s.mu.Unlock()
216

D
dongzhihong 已提交
217
	opt, ok := s.optMap[name]
218 219
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
220 221
	}

222 223 224 225 226 227
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
228
	// parameter content.
D
dongzhihong 已提交
229
	parameter.Name = name
D
dongzhihong 已提交
230
	parameter.ElementType = opt.elementType
231 232
	parameter.Content = opt.GetWeights()
	return nil
233 234
}

D
dongzhihong 已提交
235
// pserver save checkpoint
H
Helin Wang 已提交
236
func (s *Service) doCheckpoint() (err error) {
237
	<-s.initialized
D
dongzhihong 已提交
238 239
	s.mu.Lock()
	defer s.mu.Unlock()
D
dongzhihong 已提交
240

241
	cp := make([]parameterCheckpoint, len(s.optMap))
D
dongzhihong 已提交
242
	index := 0
D
dongzhihong 已提交
243
	for name, opt := range s.optMap {
244 245 246 247
		var pc parameterCheckpoint
		pc.Param.Name = name
		pc.Param.ElementType = opt.elementType
		pc.Param.Content = opt.GetWeights()
D
dongzhihong 已提交
248 249 250 251 252 253
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
H
Helin Wang 已提交
254
	err = encoder.Encode(cp)
D
dongzhihong 已提交
255
	if err != nil {
H
Helin Wang 已提交
256
		return
D
dongzhihong 已提交
257 258 259 260
	}

	cpMeta := checkpointMeta{}
	cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
261
	cpMeta.Timestamp = time.Now().UnixNano()
D
dongzhihong 已提交
262
	h := md5.New()
263
	cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
D
dongzhihong 已提交
264

H
Helin Wang 已提交
265 266 267 268 269
	cpMetajson, err := json.Marshal(cpMeta)
	if err != nil {
		return
	}

270
	err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
D
dongzhihong 已提交
271
	if err != nil {
H
Helin Wang 已提交
272
		return
D
dongzhihong 已提交
273 274 275 276 277
	}
	if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
		log.Info("checkpoint does not exists.")
	} else {
		err = os.Remove(cpMeta.UUID)
278 279 280 281 282
		if err != nil {
			log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
		} else {
			log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
		}
D
dongzhihong 已提交
283 284 285
	}
	f, err := os.Create(cpMeta.UUID)
	if err != nil {
H
Helin Wang 已提交
286
		return
D
dongzhihong 已提交
287
	}
H
Helin Wang 已提交
288 289 290 291 292 293 294 295 296 297 298 299 300

	defer func() {
		closeErr := f.Close()
		if closeErr != nil {
			if err != nil {
				log.Errorln(closeErr)
			} else {
				// Set closeErr as return value.
				err = closeErr
			}
		}
	}()

D
dongzhihong 已提交
301 302 303
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	if err != nil {
H
Helin Wang 已提交
304
		return
D
dongzhihong 已提交
305
	}
H
Helin Wang 已提交
306 307 308 309 310 311 312

	err = writer.Flush()
	if err != nil {
		return
	}

	return
313
}