service.go 5.5 KB
Newer Older
1 2 3
package pserver

import (
D
dongzhihong 已提交
4 5 6 7
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
D
dongzhihong 已提交
8
	"encoding/hex"
D
dongzhihong 已提交
9
	"encoding/json"
10 11
	"errors"
	"fmt"
D
dongzhihong 已提交
12
	"os"
D
dongzhihong 已提交
13
	"path/filepath"
D
dongzhihong 已提交
14
	"strconv"
15
	"sync"
D
dongzhihong 已提交
16 17 18
	"time"

	log "github.com/sirupsen/logrus"
19 20 21 22 23
)

// ElementType is the type of elements of a Parameter. It is an int-backed
// enumeration; the supported values are the constants declared with iota
// in this file.
type ElementType int

24
// Error messages used by the Service RPC methods. They are plain strings
// that the methods wrap with errors.New.
const (
	// AlreadyInitialized is the error message returned by InitParam and
	// FinishInitParams once the pserver has finished initialization.
	AlreadyInitialized = "pserver already initialized"
	// Uninitialized is the error message returned by SendGrad while the
	// pserver has not finished initialization.
	Uninitialized = "pserver not fully initialized"
)
30 31 32 33 34 35 36 37 38 39 40 41 42 43 44

// Supported element types. iota numbers them consecutively from 0 in
// declaration order (Int32 = 0 through Float64 = 5), so inserting or
// reordering entries changes the numeric value of every later constant.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
45
	Content     []byte
46 47 48 49 50 51 52 53
}

// ParameterWithConfig contains the parameter and the configuration. It is
// the argument of Service.InitParam and the per-parameter payload of a
// checkpoint entry; in both cases it is handed to newOptimizer.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

D
dongzhihong 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
// parameterCheckPoint is the checkpoint record of a single parameter: the
// parameter with its configuration plus the optimizer state bytes.
// doCheckpoint fills State from the optimizer's GetStates, and NewService
// passes it back to newOptimizer when restoring from a Checkpoint.
type parameterCheckPoint struct {
	ParamConfig ParameterWithConfig
	State       []byte
}

// checkpointMeta is the checkpoint signature that doCheckpoint marshals to
// JSON and hands to the etcd client. UUID is the checkpoint path
// concatenated with the pserver index (also used as the checkpoint file
// name), Md5sum is a hex string derived from the encoded checkpoint, and
// Timestamp is time.Now().String() at checkpoint time.
type checkpointMeta struct {
	UUID      string `json:"uuid"`
	Md5sum    string `json:"md5sum"`
	Timestamp string `json:"timestamp"`
}

// Checkpoint is the state of one pserver shard as persisted to a file: one
// parameterCheckPoint per parameter. NewService accepts a Checkpoint to
// restore the optimizers it describes.
type Checkpoint []parameterCheckPoint

70
// Gradient is the gradient of the parameter.
D
dongzhihong 已提交
71
type Gradient Parameter
72

H
Helin Wang 已提交
73
// Service is the RPC service for pserver.
74
type Service struct {
D
dongzhihong 已提交
75 76
	initialized        chan struct{}
	idx                int
D
dongzhihong 已提交
77
	checkpointInterval time.Duration
D
dongzhihong 已提交
78 79 80 81
	checkpointPath     string
	client             *EtcdClient
	mu                 sync.Mutex
	optMap             map[string]*optimizer
82 83
}

W
wuyi05 已提交
84 85
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
D
dongzhihong 已提交
86
func NewService(idx int, seconds int, path string, client *EtcdClient, cp Checkpoint) (*Service, error) {
87
	s := &Service{
D
dongzhihong 已提交
88 89 90 91
		idx:                idx,
		checkpointInterval: time.Second * time.Duration(seconds),
		checkpointPath:     path,
		client:             client,
92
	}
D
dongzhihong 已提交
93
	s.optMap = make(map[string]*optimizer)
94
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
95 96 97 98 99 100 101 102

	if cp != nil {
		for _, item := range cp {
			p := item.ParamConfig
			st := item.State
			s.optMap[p.Param.Name] = newOptimizer(p, st)
		}
	}
W
wuyi05 已提交
103
	return s, nil
104 105
}

H
Helin Wang 已提交
106
// InitParam initializes a parameter.
107 108 109
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
110
		return errors.New(AlreadyInitialized)
111 112 113 114 115 116 117 118 119 120 121
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dongzhihong 已提交
122
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs, nil)
123 124 125
	return nil
}

H
Helin Wang 已提交
126 127
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
128 129 130
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
131
		return errors.New(AlreadyInitialized)
132 133 134 135 136 137 138
	default:
	}

	close(s.initialized)
	return nil
}

139
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
140
// optimization.
141
func (s *Service) SendGrad(g Gradient, dummy *int) error {
142 143 144
	select {
	case <-s.initialized:
	default:
145
		return errors.New(Uninitialized)
146
	}
147 148

	s.mu.Lock()
H
Helin Wang 已提交
149
	defer s.mu.Unlock()
150

D
dzhwinter 已提交
151 152
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
153
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
154
	}
155

D
dongzhihong 已提交
156
	return o.UpdateParameter(g)
157 158
}

159 160
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
161 162
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
163
	defer s.mu.Unlock()
164

D
dongzhihong 已提交
165
	opt, ok := s.optMap[name]
166 167
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
168 169
	}

170 171 172 173 174 175 176
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
D
dongzhihong 已提交
177
	parameter.Name = name
D
dongzhihong 已提交
178
	parameter.ElementType = opt.elementType
179 180
	parameter.Content = opt.GetWeights()
	return nil
181 182
}

D
dongzhihong 已提交
183 184
// pserver save checkpoint
func (s *Service) doCheckpoint() error {
185
	<-s.initialized
D
dongzhihong 已提交
186 187
	s.mu.Lock()
	defer s.mu.Unlock()
D
dongzhihong 已提交
188 189 190

	cp := make([]parameterCheckPoint, 0, len(s.optMap))
	index := 0
D
dongzhihong 已提交
191
	for name, opt := range s.optMap {
D
dongzhihong 已提交
192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
		var pc parameterCheckPoint
		pc.ParamConfig.Param.Name = name
		pc.ParamConfig.Param.ElementType = opt.elementType
		pc.ParamConfig.Param.Content = opt.GetWeights()
		pc.State = opt.GetStates()
		cp[index] = pc
		index++
	}
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
	err := encoder.Encode(cp)
	if err != nil {
		return err
	}

	cpMeta := checkpointMeta{}
	cpMeta.UUID = s.checkpointPath + strconv.Itoa(s.idx)
	cpMeta.Timestamp = time.Now().String()
	h := md5.New()
D
dongzhihong 已提交
211
	cpMeta.Md5sum = hex.EncodeToString(h.Sum(buf.Bytes()))
D
dongzhihong 已提交
212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233

	cpMetajson, err := json.Marshal(cpMeta)
	s.client.PutKey(filepath.Join(PsCheckpoint, strconv.Itoa(s.idx)), cpMetajson, 3)
	if err != nil {
		return err
	}
	if _, err = os.Stat(cpMeta.UUID); os.IsNotExist(err) {
		log.Info("checkpoint does not exists.")
	} else {
		err = os.Remove(cpMeta.UUID)
		log.Infof("checkpoint %s already exsits, removing ", cpMeta.UUID)
	}
	f, err := os.Create(cpMeta.UUID)
	defer f.Close()
	if err != nil {
		log.Errorln(err)
	}
	writer := bufio.NewWriter(f)
	_, err = writer.Write(buf.Bytes())
	writer.Flush()
	if err != nil {
		log.Errorln(err)
D
dongzhihong 已提交
234
	}
235 236
	return nil
}