service.go 3.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
package pserver

import (
	"errors"
	"fmt"
	"sync"
)

// ElementType is the type of elements of a Parameter.
type ElementType int

var (
	// ErrAlreadyInitialized is returned by BeginInitParams, InitParam
	// and FinishInitParams once FinishInitParams has completed.
	ErrAlreadyInitialized = errors.New("pserver already initialized")

	// ErrUninitialized is returned by SendGrads when it is called
	// before FinishInitParams has completed.
	ErrUninitialized = errors.New("pserver not fully initialized")
)

// Supported element types.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

type (
	// Parameter is a named piece of data to sync with the parameter
	// server. ElementType records the type of the elements held in
	// Content.
	Parameter struct {
		Name        string
		ElementType ElementType
		Content     []byte
	}

	// ParameterWithConfig pairs a Parameter with its configuration.
	ParameterWithConfig struct {
		Param  Parameter
		Config []byte // parameter configuration in Proto Buffer format
	}

	// Gradient is the gradient of a parameter. It has the same fields
	// as Parameter.
	Gradient Parameter
)

H
Helin Wang 已提交
41
// Service is the RPC service for pserver.
42 43 44 45 46 47 48 49
type Service struct {
	initialized chan struct{}

	mu       sync.Mutex
	opt      *optimizer
	paramMap map[string]Parameter
}

H
Helin Wang 已提交
50
// NewService creates a new service.
51 52 53 54 55 56 57
func NewService() *Service {
	s := &Service{}
	s.paramMap = make(map[string]Parameter)
	s.initialized = make(chan struct{})
	return s
}

H
Helin Wang 已提交
58 59
// BeginInitParams tells the parameter server that the parameter
// initialization has begun.
60 61 62
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
63
		return ErrAlreadyInitialized
64 65 66 67 68 69 70 71 72 73 74 75 76 77 78
	default:
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if s.opt != nil {
		s.opt.Cleanup()
	}

	// TODO(helin): parse learning rate from config
	s.opt = newOptimizer(sgd, 0.01)
	return nil
}

H
Helin Wang 已提交
79
// InitParam initializes a parameter.
80 81 82
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
83
		return ErrAlreadyInitialized
84 85 86 87 88 89 90 91 92 93 94 95 96 97 98
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
	s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
	return nil
}

H
Helin Wang 已提交
99 100
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
101 102 103
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
104
		return ErrAlreadyInitialized
105 106 107 108 109 110 111
	default:
	}

	close(s.initialized)
	return nil
}

H
Helin Wang 已提交
112 113
// SendGrads sends gradients to parameter servers for parameter
// optimization.
114
func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
115 116 117 118 119
	select {
	case <-s.initialized:
	default:
		return ErrUninitialized
	}
120

121 122 123 124 125
	count := len(grads)
	if count == 0 {
		return nil
	}

126
	s.mu.Lock()
H
Helin Wang 已提交
127
	defer s.mu.Unlock()
128 129 130 131 132 133 134

	for _, g := range grads {
		if _, ok := s.paramMap[g.Name]; !ok {
			return fmt.Errorf("parameter: %s does not exist", g.Name)
		}
	}

135
	errCh := make(chan error, count)
136 137
	for _, g := range grads {
		go func(p Parameter, g Gradient) {
138 139
			err := s.opt.UpdateParameter(p, g)
			errCh <- err
140 141 142
		}(s.paramMap[g.Name], g)
	}

143 144 145 146 147 148 149 150 151 152 153
	recv := 0
	for err := range errCh {
		if err != nil {
			return err
		}

		recv++
		if recv == count {
			break
		}
	}
154 155 156
	return nil
}

H
Helin Wang 已提交
157
// GetParams gets parameters from the parameter server.
158 159 160
func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
161
	defer s.mu.Unlock()
162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183

	for _, n := range names {
		if _, ok := s.paramMap[n]; !ok {
			return fmt.Errorf("parameter: %s does not exist", n)
		}
	}

	*parameters = make([]Parameter, len(names))
	for i, n := range names {
		// The parameter content (a byte slice) may change
		// during RPC serialization due to write from other
		// goroutine, we allow it since mini-batch based deep
		// learning optimization methods are stochastic in
		// nature. This race condition is allowed deliberately
		// to save the program from making a copy of the
		// paramter content.
		(*parameters)[i] = s.paramMap[n]
	}

	return nil
}

H
Helin Wang 已提交
184 185
// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
186 187 188 189 190
	<-s.initialized

	// TODO
	return nil
}