// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pserver

import (
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
	"encoding/hex"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"path/filepath"
	"strconv"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"
)

// ElementType is the type of elements of a Parameter.
type ElementType int

// Supported element types. The numeric values are stored in checkpoints,
// so the order below must not change.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// ErrCheckpointNotFound indicates that the pserver checkpoint could
// not be found.
var ErrCheckpointNotFound = errors.New("checkpoint not found")

// RPC error messages. The Service methods wrap them with errors.New at
// each call site.
const (
	AlreadyInitialized  = "pserver already initialized"
	Uninitialized       = "pserver not fully initialized"
	CheckpointMD5Failed = "checkpoint file MD5 validation failed"
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
64
	Content     []byte
65 66 67 68 69 70 71 72
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

73
// checkpointMeta saves checkpoint metadata
D
dongzhihong 已提交
74 75
type checkpointMeta struct {
	UUID      string `json:"uuid"`
76 77
	MD5       string `json:"md5"`
	Timestamp int64  `json:"timestamp"`
D
dongzhihong 已提交
78 79 80
}

// Checkpoint is the pserver shard persist in file
81
type Checkpoint []parameterCheckpoint
D
dongzhihong 已提交
82

83
// Gradient is the gradient of the parameter.
D
dongzhihong 已提交
84
type Gradient Parameter
85

H
Helin Wang 已提交
86
// Service is the RPC service for pserver.
87
type Service struct {
D
dongzhihong 已提交
88 89
	initialized        chan struct{}
	idx                int
D
dongzhihong 已提交
90
	checkpointInterval time.Duration
D
dongzhihong 已提交
91 92 93 94
	checkpointPath     string
	client             *EtcdClient
	mu                 sync.Mutex
	optMap             map[string]*optimizer
95 96
}

97 98 99 100 101 102 103
// parameterCheckpoint saves parameter checkpoint
type parameterCheckpoint struct {
	ParameterWithConfig
	State []byte
}

// NewCheckpointFromFile loads parameters and state from checkpoint file
D
dongzhihong 已提交
104
func NewCheckpointFromFile(cpPath string, idx int, e *EtcdClient) (Checkpoint, error) {
105 106 107 108 109
	v, err := e.GetKey(PsPath+string(idx), 3*time.Second)
	if err != nil {
		return nil, err
	}

110 111 112 113
	if len(v) == 0 {
		return nil, ErrCheckpointNotFound
	}

114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134
	var cpMeta checkpointMeta
	if err = json.Unmarshal(v, &cpMeta); err != nil {
		return nil, err
	}

	fn := filepath.Join(cpPath, cpMeta.UUID)
	if _, err = os.Stat(fn); os.IsNotExist(err) {
		return nil, err
	}
	content, err := ioutil.ReadFile(fn)
	if err != nil {
		return nil, err
	}

	h := md5.New()
	md5 := hex.EncodeToString(h.Sum(content))
	if md5 != cpMeta.MD5 {
		return nil, errors.New(CheckpointMD5Failed)
	}

	dec := gob.NewDecoder(bytes.NewReader(content))
D
dongzhihong 已提交
135
	cp := Checkpoint{}
136 137 138 139 140 141
	if err = dec.Decode(cp); err != nil {
		return nil, err
	}
	return cp, nil
}

W
wuyi05 已提交
142
// NewService creates a new service, will bypass etcd registration if no
143
// endpoints specified. It will recovery from checkpoint file if a exists a specified checkpoint.
D
dongzhihong 已提交
144
func NewService(idx int, interval time.Duration, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
145
	s := &Service{
D
dongzhihong 已提交
146
		idx:                idx,
147
		checkpointInterval: interval,
D
dongzhihong 已提交
148 149
		checkpointPath:     path,
		client:             client,
150
	}
D
dongzhihong 已提交
151
	s.optMap = make(map[string]*optimizer)
152
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
153 154

	if cp != nil {
D
dongzhihong 已提交
155
		for _, item := range cp {
156 157 158 159 160
			p := ParameterWithConfig{
				Param:  item.Param,
				Config: item.Config,
			}
			s.optMap[p.Param.Name] = newOptimizer(p, item.State)
D
dongzhihong 已提交
161 162
		}
	}
W
wuyi05 已提交
163
	return s, nil
164 165
}

H
Helin Wang 已提交
166
// InitParam initializes a parameter.
167 168 169
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
170
		return errors.New(AlreadyInitialized)
171 172 173 174 175 176 177 178 179 180 181
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dongzhihong 已提交
182
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
183 184 185
	return nil
}

H
Helin Wang 已提交
186 187
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
188 189 190
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
191
		return errors.New(AlreadyInitialized)
192 193 194 195 196 197 198
	default:
	}

	close(s.initialized)
	return nil
}

199
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
200
// optimization.
201
func (s *Service) SendGrad(g Gradient, dummy *int) error {
202 203 204
	select {
	case <-s.initialized:
	default:
205
		return errors.New(Uninitialized)
206
	}
207 208

	s.mu.Lock()
H
Helin Wang 已提交
209
	defer s.mu.Unlock()
210

D
dzhwinter 已提交
211 212
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
213
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
214
	}
215

D
dongzhihong 已提交
216
	return o.UpdateParameter(g)
217 218
}

219 220
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
221 222
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
223
	defer s.mu.Unlock()
224

D
dongzhihong 已提交
225
	opt, ok := s.optMap[name]
226 227
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
228 229
	}

230 231 232 233 234 235
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
236
	// parameter content.
D
dongzhihong 已提交
237
	parameter.Name = name
D
dongzhihong 已提交
238
	parameter.ElementType = opt.elementType
239 240
	parameter.Content = opt.GetWeights()
	return nil
241 242
}

D
dongzhihong 已提交
243
// pserver save checkpoint
H
Helin Wang 已提交
244
func (s *Service) doCheckpoint() (err error) {
245
	<-s.initialized
D
dongzhihong 已提交
246 247
	s.mu.Lock()
	defer s.mu.Unlock()
D
dongzhihong 已提交
248

249
	cp := make([]parameterCheckpoint, len(s.optMap))
D
dongzhihong 已提交
250
	index := 0
D
dongzhihong 已提交
251
	for name, opt := range s.optMap {
252 253 254 255
		var pc parameterCheckpoint
		pc.Param.Name = name
		pc.Param.ElementType = opt.elementType
		pc.Param.Content = opt.GetWeights()
D
dongzhihong 已提交
256 257 258 259 260 261
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
H
Helin Wang 已提交
262
	err = encoder.Encode(cp)
D
dongzhihong 已提交
263
	if err != nil {
H
Helin Wang 已提交
264
		return
D
dongzhihong 已提交
265 266 267 268
	}

	cpMeta := checkpointMeta{}
	cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
269
	cpMeta.Timestamp = time.Now().UnixNano()
D
dongzhihong 已提交
270
	h := md5.New()
271
	cpMeta.MD5 = hex.EncodeToString(h.Sum(buf.Bytes()))
D
dongzhihong 已提交
272

H
Helin Wang 已提交
273 274 275 276 277
	cpMetajson, err := json.Marshal(cpMeta)
	if err != nil {
		return
	}

278
	err = s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3*time.Second)
D
dongzhihong 已提交
279
	if err != nil {
H
Helin Wang 已提交
280
		return
D
dongzhihong 已提交
281 282 283 284 285
	}
	if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
		log.Info("checkpoint does not exists.")
	} else {
		err = os.Remove(cpMeta.UUID)
286 287 288 289 290
		if err != nil {
			log.Infof("Removing checkpoint %s failed", cpMeta.UUID)
		} else {
			log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
		}
D
dongzhihong 已提交
291 292 293
	}
	f, err := os.Create(cpMeta.UUID)
	if err != nil {
H
Helin Wang 已提交
294
		return
D
dongzhihong 已提交
295
	}
H
Helin Wang 已提交
296 297 298 299 300 301 302 303 304 305 306 307 308

	defer func() {
		closeErr := f.Close()
		if closeErr != nil {
			if err != nil {
				log.Errorln(closeErr)
			} else {
				// Set closeErr as return value.
				err = closeErr
			}
		}
	}()

D
dongzhihong 已提交
309 310 311
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	if err != nil {
H
Helin Wang 已提交
312
		return
D
dongzhihong 已提交
313
	}
H
Helin Wang 已提交
314 315 316 317 318 319 320

	err = writer.Flush()
	if err != nil {
		return
	}

	return
321
}