service.go 5.2 KB
Newer Older
1 2 3
package pserver

import (
D
dongzhihong 已提交
4 5 6 7 8
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/gob"
	"encoding/hex"
9 10
	"errors"
	"fmt"
D
dongzhihong 已提交
11 12
	"os"
	"strconv"
13
	"sync"
D
dongzhihong 已提交
14 15 16
	"time"

	log "github.com/sirupsen/logrus"
17 18 19 20 21
)

// ElementType is the type of elements of a Parameter.
// Its valid values are the constants declared with iota in this file
// (Int32 through Float64).
type ElementType int

22 23 24 25
// Error messages used by the Service RPC methods; each is turned into an
// error with errors.New at the point of return.
const (
	// AlreadyInitialized is the message of the error returned by InitParam
	// and FinishInitParams once the initialized channel has been closed.
	AlreadyInitialized = "pserver already initialized"
	// Uninitialized is the message of the error returned by SendGrad when
	// it is called before the initialized channel has been closed.
	Uninitialized      = "pserver not fully initialized"
)
26

D
dongzhihong 已提交
27
const (
	// checkpoint_path is the directory prefix (trailing slash included)
	// under which Save writes checkpoint files; Save appends the pserver
	// index to it to form the file path.
	checkpoint_path = "./checkpoints/"
)

31 32 33 34 35 36 37 38 39 40
// Supported element types.
//
// The values are assigned sequentially by iota starting at 0 (Int32 == 0,
// Float64 == 5), so reordering or inserting entries changes the numeric
// value of every constant that follows.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

W
wuyi05 已提交
41 42 43
// PsDesired is the etcd path under which the desired pserver count is stored.
const PsDesired = "/ps_desired"

44 45 46 47
// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string
	ElementType ElementType
D
dzhwinter 已提交
48
	Content     []byte
49 50 51 52 53 54
}

// ParameterWithConfig contains the parameter and the configuration.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
D
dongzhihong 已提交
55
	State  []byte // parameter training state
56 57 58 59 60
}

// Gradient is the gradient of the parameter.
// It has the same fields as Parameter; SendGrad uses its Name to find the
// optimizer the gradient is applied to.
type Gradient Parameter

H
Helin Wang 已提交
61
// Service is the RPC service for pserver.
62 63
type Service struct {
	initialized chan struct{}
64
	idx         int
65

66
	mu     sync.Mutex
D
dzhwinter 已提交
67
	optMap map[string]*optimizer
68 69
}

D
dongzhihong 已提交
70 71 72 73
// checkpoint is the metadata record Save builds for a written checkpoint.
// At present Save only gob-encodes and logs it; a TODO in Save notes that it
// should eventually be stored in etcd as JSON.
type checkpoint struct {
	// Uuid is set by Save to the path of the checkpoint file.
	Uuid string
	// Md5sum is a hex-encoded digest of the serialized parameter content.
	Md5sum string
	// Timestamp is set by Save to time.Now().String() at save time.
	Timestamp string
}

// GetBytes gob-encodes its arguments and returns the resulting bytes.
//
// The arguments are encoded together as a single []interface{} value (the
// variadic slice itself), not as a stream of separate values, so a reader
// has to decode into a []interface{}. Because the elements travel as
// interface values, their concrete types must be known to gob (basic types
// are; others need gob.Register).
//
// On an encoding failure it returns a nil slice and the encoder's error.
func GetBytes(content ...interface{}) ([]byte, error) {
	out := new(bytes.Buffer)
	if err := gob.NewEncoder(out).Encode(content); err != nil {
		return nil, err
	}
	return out.Bytes(), nil
}

W
wuyi05 已提交
88 89
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
90 91 92 93
func NewService(idx int) (*Service, error) {
	s := &Service{
		idx: idx,
	}
D
dongzhihong 已提交
94
	s.optMap = make(map[string]*optimizer)
95
	s.initialized = make(chan struct{})
D
dongzhihong 已提交
96 97
	gob.Register(ParameterWithConfig{})
	gob.Register(checkpoint{})
W
wuyi05 已提交
98
	return s, nil
99 100
}

H
Helin Wang 已提交
101
// InitParam initializes a parameter.
102 103 104
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
105
		return errors.New(AlreadyInitialized)
106 107 108 109 110 111 112 113 114 115 116
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
D
dzhwinter 已提交
117
	s.optMap[paramWithConfigs.Param.Name] = newOptimizer(paramWithConfigs)
118 119 120
	return nil
}

H
Helin Wang 已提交
121 122
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
123 124 125
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
126
		return errors.New(AlreadyInitialized)
127 128 129 130 131 132 133
	default:
	}

	close(s.initialized)
	return nil
}

134
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
135
// optimization.
136
func (s *Service) SendGrad(g Gradient, dummy *int) error {
137 138 139
	select {
	case <-s.initialized:
	default:
140
		return errors.New(Uninitialized)
141
	}
142 143

	s.mu.Lock()
H
Helin Wang 已提交
144
	defer s.mu.Unlock()
145

D
dzhwinter 已提交
146 147
	o, ok := s.optMap[g.Name]
	if !ok {
D
dzhwinter 已提交
148
		return fmt.Errorf("parameter: %s does not exist", g.Name)
D
dzhwinter 已提交
149
	}
150

D
dongzhihong 已提交
151
	return o.UpdateParameter(g)
152 153
}

154 155
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
156 157
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
158
	defer s.mu.Unlock()
159

D
dongzhihong 已提交
160
	opt, ok := s.optMap[name]
161 162
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
163 164
	}

165 166 167 168 169 170 171
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
D
dongzhihong 已提交
172
	parameter.Name = name
D
dongzhihong 已提交
173
	parameter.ElementType = opt.elementType
174 175
	parameter.Content = opt.GetWeights()
	return nil
176 177
}

H
Helin Wang 已提交
178 179
// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
D
dongzhihong 已提交
180 181
	//FIXME: checkpoint is only used by pserver
	// and has a constant path of */checkpoints/{pserver_idx}*
182
	<-s.initialized
D
dongzhihong 已提交
183 184 185 186 187 188 189 190 191 192 193 194
	s.mu.Lock()
	defer s.mu.Unlock()
	var paramWithConfig ParameterWithConfig
	for name, opt := range s.optMap {
		paramWithConfig.Param.Name = name
		paramWithConfig.Param.ElementType = opt.elementType
		paramWithConfig.Param.Content = opt.GetWeights()
		paramWithConfig.State = opt.GetStates()
		content, err := GetBytes(paramWithConfig)
		if err != nil {
			log.Errorln(err)
		}
D
dongzhihong 已提交
195
		ck := checkpoint{}
D
dongzhihong 已提交
196
		h := md5.New()
D
dongzhihong 已提交
197 198 199
		ck.Md5sum = hex.EncodeToString(h.Sum(content))
		ck.Timestamp = time.Now().String()
		ck.Uuid = checkpoint_path + strconv.Itoa(s.idx)
D
dongzhihong 已提交
200 201 202 203
		ckbytes, err := GetBytes(ck)
		if err != nil {
			log.Errorln(err)
		}
D
dongzhihong 已提交
204 205
		// TODO: according design doc, need to save Uuid to etcd in json format
		// {\"Uuid\": [UUID], \"md5\", \"MD5 sum\", \"Timestamp\": xxxx}
D
dongzhihong 已提交
206 207
		log.Infof("parameter checkpoint %s", ckbytes)

D
dongzhihong 已提交
208
		if _, err = os.Stat(ck.Uuid); os.IsNotExist(err) {
D
dongzhihong 已提交
209 210
			log.Info("checkpoint not exists.")
		} else {
D
dongzhihong 已提交
211 212
			err = os.Remove(ck.Uuid)
			log.Infof("remove %s", ck.Uuid)
D
dongzhihong 已提交
213
		}
D
dongzhihong 已提交
214
		f, err := os.Create(ck.Uuid)
D
dongzhihong 已提交
215 216 217 218 219 220
		defer f.Close()
		if err != nil {
			log.Errorln(err)
		}
		writer := bufio.NewWriter(f)
		_, err = writer.Write(content)
D
dongzhihong 已提交
221
		writer.Flush()
D
dongzhihong 已提交
222 223
		if err != nil {
			log.Errorln(err)
D
dongzhihong 已提交
224 225
		}
	}
226 227
	return nil
}