// Package pserver implements the parameter server: it holds named
// parameters and applies gradients to them through an optimizer.
package pserver

import (
	"errors"
	"fmt"
	"sync"
)

// ElementType identifies the type of the elements stored in a
// Parameter. The supported values are the ElementType constants
// declared in this file (Int32 through Float64).
type ElementType int

// ErrAlreadyInitialized is returned by the initialization methods
// (BeginInitParams, InitParam and FinishInitParams) when they are
// called after initialization has already been finished.
var ErrAlreadyInitialized = errors.New("pserver already initialized")

// Supported element types.
//
// The values are assigned by iota in declaration order, starting at
// 0 for Int32. Reordering the list or inserting an entry in the
// middle therefore changes the numeric value of the constants that
// follow.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
	Name        string      // identifies the parameter; Service keys its parameter map by it
	ElementType ElementType // type of the elements held in Content
	Content     []byte      // raw parameter data; presumably elements encoded per ElementType (TODO confirm layout)
}

// ParameterWithConfig contains the parameter and the configuration.
// It is the argument type of Service.InitParam, which currently
// stores Param and does not yet parse Config.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

// Gradient is the gradient of the parameter. It has the same fields
// as Parameter; Service.SendGrads uses Name to look up the parameter
// that the gradient applies to.
type Gradient Parameter

// Service is the parameter server. It is first populated through
// BeginInitParams, InitParam and FinishInitParams, and afterwards
// serves SendGrads, GetParams and SaveModel.
//
// The exported methods follow the (args, reply pointer) error shape,
// presumably so the type can be registered as an RPC service
// (TODO confirm against the code that registers it).
type Service struct {
	// initialized is closed by FinishInitParams. SendGrads,
	// GetParams and SaveModel block on it, while the initialization
	// methods use it to reject calls made after initialization.
	initialized chan struct{}

	// mu guards opt and paramMap.
	mu       sync.Mutex
	opt      *optimizer
	paramMap map[string]Parameter
}

// NewService returns a Service in the uninitialized state: the
// parameter map is empty, the initialized channel is open, and no
// optimizer has been installed yet.
func NewService() *Service {
	return &Service{
		initialized: make(chan struct{}),
		paramMap:    make(map[string]Parameter),
	}
}

func (s *Service) BeginInitParams(config []byte, dummy *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
58
		return ErrAlreadyInitialized
59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
	default:
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if s.opt != nil {
		s.opt.Cleanup()
	}

	// TODO(helin): parse learning rate from config
	s.opt = newOptimizer(sgd, 0.01)
	return nil
}

func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
77
		return ErrAlreadyInitialized
78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
	s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
	return nil
}

func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
96
		return ErrAlreadyInitialized
97 98 99 100 101 102 103 104
	default:
	}

	close(s.initialized)
	return nil
}

func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
H
Helin Wang 已提交
105
	<-s.initialized
106

107 108 109 110 111
	count := len(grads)
	if count == 0 {
		return nil
	}

112
	s.mu.Lock()
H
Helin Wang 已提交
113
	defer s.mu.Unlock()
114 115 116 117 118 119 120

	for _, g := range grads {
		if _, ok := s.paramMap[g.Name]; !ok {
			return fmt.Errorf("parameter: %s does not exist", g.Name)
		}
	}

121
	errCh := make(chan error, count)
122 123
	for _, g := range grads {
		go func(p Parameter, g Gradient) {
124 125
			err := s.opt.UpdateParameter(p, g)
			errCh <- err
126 127 128
		}(s.paramMap[g.Name], g)
	}

129 130 131 132 133 134 135 136 137 138 139
	recv := 0
	for err := range errCh {
		if err != nil {
			return err
		}

		recv++
		if recv == count {
			break
		}
	}
140 141 142 143 144 145
	return nil
}

func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
146
	defer s.mu.Unlock()
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174

	for _, n := range names {
		if _, ok := s.paramMap[n]; !ok {
			return fmt.Errorf("parameter: %s does not exist", n)
		}
	}

	*parameters = make([]Parameter, len(names))
	for i, n := range names {
		// The parameter content (a byte slice) may change
		// during RPC serialization due to write from other
		// goroutine, we allow it since mini-batch based deep
		// learning optimization methods are stochastic in
		// nature. This race condition is allowed deliberately
		// to save the program from making a copy of the
		// paramter content.
		(*parameters)[i] = s.paramMap[n]
	}

	return nil
}

// SaveModel blocks until initialization has finished and then
// returns nil. Saving is not implemented yet: path is ignored and
// dummy is never written.
func (s *Service) SaveModel(path string, dummy *int) error {
	<-s.initialized

	// TODO: save the model to path.
	return nil
}