service.go 5.2 KB
Newer Older
1 2 3
package pserver

import (
D
dongzhihong 已提交
4 5 6 7 8
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
	"encoding/hex"
9 10
	"errors"
	"fmt"
D
dongzhihong 已提交
11 12
	"os"
	"strconv"
13
	"sync"
D
dongzhihong 已提交
14 15 16
	"time"

	log "github.com/sirupsen/logrus"
17 18 19 20 21
)

// ElementType identifies the data type of each element stored in a
// Parameter's Content. Supported values are Int32, UInt32, Int64, UInt64,
// Float32 and Float64.
type ElementType int

// Error messages returned by Service's RPC methods, depending on whether
// FinishInitParams has already been called.
const (
	// AlreadyInitialized is the error message returned by InitParam and
	// FinishInitParams once FinishInitParams has completed.
	AlreadyInitialized = "pserver already initialized"
	// Uninitialized is the error message returned by SendGrad when
	// FinishInitParams has not completed yet.
	Uninitialized = "pserver not fully initialized"
)

// checkpoint_path is the directory (relative to the working directory) in
// which Save writes this pserver's checkpoint; the file name is the pserver
// index.
// NOTE(review): Go style would be checkpointPath; the name is kept because
// other files in the package may reference it.
const checkpoint_path = "./checkpoints/"

// Supported element types. The numeric values are persisted (the
// ElementType of every Parameter is gob-encoded into checkpoints by Save),
// so existing constants must keep their values.
const (
	Int32   ElementType = 0
	UInt32  ElementType = 1
	Int64   ElementType = 2
	UInt64  ElementType = 3
	Float32 ElementType = 4
	Float64 ElementType = 5
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
47
	Content     []byte
48 49 50 51 52 53
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
D
dongzhihong 已提交
54
	State  []byte // parameter training state
55 56 57 58 59
}

// Gradient is the gradient of the parameter.
type Gradient Parameter

H
Helin Wang 已提交
60
// Service is the RPC service for pserver.
61 62
type Service struct {
	initialized chan struct{}
63
	idx         int
64

65
	mu     sync.Mutex
D
dzhwinter 已提交
66
	optMap map[string]*optimizer
67 68
}

D
dongzhihong 已提交
69 70 71 72
// checkpoint is the metadata Save records for a checkpoint file. The field
// names are part of its gob encoding (the type is registered in NewService),
// so they are kept as-is rather than renamed to UUID/MD5Sum.
type checkpoint struct {
	Uuid      string // checkpoint file path, used as its identifier
	Md5sum    string // hex-encoded MD5 checksum of the checkpoint
	Timestamp string // time.Now().String() at save time
}

// GetBytes gob-encodes its arguments as a single []interface{} value and
// returns the encoded bytes. Because the values travel inside interface
// values, non-builtin concrete types must be registered with gob.Register
// first (NewService registers ParameterWithConfig and checkpoint).
func GetBytes(content ...interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	if err := gob.NewEncoder(out).Encode(content); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
89 90 91 92
func NewService(idx int) (*Service, error) {
	s := &Service{
		idx: idx,
	}
D
dongzhihong 已提交
93
	s.optMap = make(map[string]*optimizer)
94
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
95 96
	gob.Register(ParameterWithConfig{})
	gob.Register(checkpoint{})
W
wuyi05 已提交
97
	return s, nil
98 99
}

H
Helin Wang 已提交
100
// InitParam initializes a parameter.
101 102 103
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
104
		return errors.New(AlreadyInitialized)
105 106 107 108 109 110 111 112 113 114 115
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dzhwinter 已提交
116
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
117 118 119
	return nil
}

H
Helin Wang 已提交
120 121
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
122 123 124
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
125
		return errors.New(AlreadyInitialized)
126 127 128 129 130 131 132
	default:
	}

	close(s.initialized)
	return nil
}

133
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
134
// optimization.
135
func (s *Service) SendGrad(g Gradient, dummy *int) error {
136 137 138
	select {
	case <-s.initialized:
	default:
139
		return errors.New(Uninitialized)
140
	}
141 142

	s.mu.Lock()
H
Helin Wang 已提交
143
	defer s.mu.Unlock()
144

D
dzhwinter 已提交
145 146
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
147
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
148
	}
149

D
dongzhihong 已提交
150
	return o.UpdateParameter(g)
151 152
}

153 154
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
155 156
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
157
	defer s.mu.Unlock()
158

D
dongzhihong 已提交
159
	opt, ok := s.optMap[name]
160 161
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
162 163
	}

164 165 166 167 168 169 170
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
D
dongzhihong 已提交
171
	parameter.Name = name
D
dongzhihong 已提交
172
	parameter.ElementType = opt.elementType
173 174
	parameter.Content = opt.GetWeights()
	return nil
175 176
}

H
Helin Wang 已提交
177 178
// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
D
dongzhihong 已提交
179 180
	//FIXME: checkpoint is only used by pserver
	// and has a constant path of */checkpoints/{pserver_idx}*
181
	<-s.initialized
D
dongzhihong 已提交
182 183 184 185 186 187 188 189 190 191 192 193
	s.mu.Lock()
	defer s.mu.Unlock()
	var paramWithConfig ParameterWithConfig
	for name, opt := range s.optMap {
		paramWithConfig.Param.Name = name
		paramWithConfig.Param.ElementType = opt.elementType
		paramWithConfig.Param.Content = opt.GetWeights()
		paramWithConfig.State = opt.GetStates()
		content, err := GetBytes(paramWithConfig)
		if err != nil {
			log.Errorln(err)
		}
D
dongzhihong 已提交
194
		ck := checkpoint{}
D
dongzhihong 已提交
195
		h := md5.New()
D
dongzhihong 已提交
196 197 198
		ck.Md5sum = hex.EncodeToString(h.Sum(content))
		ck.Timestamp = time.Now().String()
		ck.Uuid = checkpoint_path + strconv.Itoa(s.idx)
D
dongzhihong 已提交
199 200 201 202
		ckbytes, err := GetBytes(ck)
		if err != nil {
			log.Errorln(err)
		}
D
dongzhihong 已提交
203 204
		// TODO: according design doc, need to save Uuid to etcd in json format
		// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
D
dongzhihong 已提交
205 206
		log.Infof("parameter checkpoint %s", ckbytes)

D
dongzhihong 已提交
207
		if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) {
D
dongzhihong 已提交
208 209
			log.Info("checkpoint not exists.")
		} else {
D
dongzhihong 已提交
210 211
			err = os.Remove(ck.Uuid)
			log.Infof("remove %s", ck.Uuid)
D
dongzhihong 已提交
212
		}
D
dongzhihong 已提交
213
		f, err := os.Create(ck.Uuid)
D
dongzhihong 已提交
214 215 216 217 218 219
		defer f.Close()
		if err != nil {
			log.Errorln(err)
		}
		writer := bufio.NewWriter(f)
		_, err = writer.Write(content)
D
dongzhihong 已提交
220
		writer.Flush()
D
dongzhihong 已提交
221 222
		if err != nil {
			log.Errorln(err)
D
dongzhihong 已提交
223 224
		}
	}
225 226
	return nil
}