package pserver

import (
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
	"encoding/hex"
	"errors"
	"fmt"
	"os"
	"strconv"
	"sync"
	"time"

	log "github.com/sirupsen/logrus"
)

// ElementType is the type of the elements stored in a Parameter's Content.
type ElementType int

// Messages of the errors the Service RPCs return when a call arrives in the
// wrong initialization phase.
const (
	// AlreadyInitialized is the message of the error returned by InitParam
	// and FinishInitParams once FinishInitParams has completed.
	AlreadyInitialized = "pserver already initialized"

	// Uninitialized is the message of the error returned by SendGrad before
	// FinishInitParams has been called.
	Uninitialized = "pserver not fully initialized"
)

// checkpoint_path is the directory prefix of the checkpoint file; Save
// appends the pserver index to it.
//
// NOTE(review): the name is not idiomatic Go (checkpointPath), but it is kept
// because other files in the package may refer to it.
const checkpoint_path = "/checkpoints/"

// Supported element types. The values are assigned by iota in declaration
// order, so reordering this list changes them.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// PsDesired is the etcd path that stores the desired pserver count.
const PsDesired = "/ps_desired"

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
48
	Content     []byte
49 50 51 52 53 54
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
D
dongzhihong 已提交
55
	State  []byte // parameter training state
56 57 58 59 60
}

// Gradient is the gradient of the parameter.
type Gradient Parameter

// Service is the RPC service for pserver.
62 63
type Service struct {
	initialized chan struct{}
64
	idx         int
65

66
	mu     sync.Mutex
D
dzhwinter 已提交
67
	optMap map[string]*optimizer
68 69
}

// Checkpoint describes one checkpoint written by Save.
//
// NOTE(review): every field is unexported, so encoding/gob (and
// encoding/json) cannot serialize this type. The design doc wants it stored
// in etcd as JSON, which will require exported, tagged fields.
type Checkpoint struct {
	uuid      string // checkpoint file path: checkpoint_path + pserver index
	md5sum    string // hex-encoded MD5 digest, set by Save
	timestamp string // time.Now().String() at the moment Save ran
}

// GetBytes gob-encodes each value in content, in order, into a single byte
// stream (for example a ParameterWithConfig). The stream can be read back
// with one gob.Decoder by calling Decode once per value, in the same order.
//
// Each value is encoded with its own concrete type, so callers do not need to
// gob.Register the types they pass in.
func GetBytes(content ...interface{}) ([]byte, error) {
	var buf bytes.Buffer
	encoder := gob.NewEncoder(&buf)
	for _, c := range content {
		// Encoding the variadic slice itself would encode a []interface{},
		// for which gob requires every element's concrete type to be
		// registered; encoding each element avoids that.
		if err := encoder.Encode(c); err != nil {
			return nil, err
		}
	}
	return buf.Bytes(), nil
}

// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
90 91 92 93
func NewService(idx int) (*Service, error) {
	s := &Service{
		idx: idx,
	}
D
dongzhihong 已提交
94
	s.optMap = make(map[string]*optimizer)
95
	s.initialized = make(chan struct{})
W
wuyi05 已提交
96
	return s, nil
97 98
}

// InitParam initializes a parameter.
100 101 102
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
103
		return errors.New(AlreadyInitialized)
104 105 106 107 108 109 110 111 112 113 114
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dzhwinter 已提交
115
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
116 117 118
	return nil
}

// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
121 122 123
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
124
		return errors.New(AlreadyInitialized)
125 126 127 128 129 130 131
	default:
	}

	close(s.initialized)
	return nil
}

// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
133
// optimization.
134
func (s *Service) SendGrad(g Gradient, dummy *int) error {
135 136 137
	select {
	case <-s.initialized:
	default:
138
		return errors.New(Uninitialized)
139
	}
140 141

	s.mu.Lock()
H
Helin Wang 已提交
142
	defer s.mu.Unlock()
143

D
dzhwinter 已提交
144 145
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
146
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
147
	}
148

D
dongzhihong 已提交
149
	return o.UpdateParameter(g)
150 151
}

// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
154 155
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
156
	defer s.mu.Unlock()
157

D
dongzhihong 已提交
158
	opt, ok := s.optMap[name]
159 160
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
161 162
	}

163 164 165 166 167 168 169
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
D
dongzhihong 已提交
170
	parameter.Name = name
D
dongzhihong 已提交
171
	parameter.ElementType = opt.elementType
172 173
	parameter.Content = opt.GetWeights()
	return nil
174 175
}

// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
D
dongzhihong 已提交
178 179
	//FIXME: checkpoint is only used by pserver
	// and has a constant path of */checkpoints/{pserver_idx}*
180
	<-s.initialized
D
dongzhihong 已提交
181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220
	s.mu.Lock()
	defer s.mu.Unlock()
	var paramWithConfig ParameterWithConfig
	for name, opt := range s.optMap {
		paramWithConfig.Param.Name = name
		paramWithConfig.Param.ElementType = opt.elementType
		paramWithConfig.Param.Content = opt.GetWeights()
		paramWithConfig.State = opt.GetStates()
		content, err := GetBytes(paramWithConfig)
		if err != nil {
			log.Errorln(err)
		}
		ck := Checkpoint{}
		h := md5.New()
		ck.md5sum = hex.EncodeToString(h.Sum(content))
		ck.timestamp = time.Now().String()
		ck.uuid = checkpoint_path + strconv.Itoa(s.idx)
		ckbytes, err := GetBytes(ck)
		if err != nil {
			log.Errorln(err)
		}
		// TODO: according design doc, need to save uuid to etcd in json format
		// {\"uuid\": [UUID], \"md5\", \"MD5 sum\", \"timestamp\": xxxx}
		log.Infof("parameter checkpoint %s", ckbytes)

		if _, err = os.Stat(ck.uuid); os.IsNotExist(err) {
			log.Info("checkpoint not exists.")
		} else {
			err = os.Remove(ck.uuid)
			log.Infof("remove %s", ck.uuid)
		}
		f, err := os.Create(ck.uuid)
		defer f.Close()
		if err != nil {
			log.Errorln(err)
		}
		writer := bufio.NewWriter(f)
		_, err = writer.Write(content)
		if err != nil {
			log.Errorln(err)
D
dongzhihong 已提交
221 222
		}
	}
223 224
	return nil
}