service.go 3.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
package pserver

import (
	"errors"
	"fmt"
	"sync"
)

// ElementType is the type of the elements stored in a Parameter.
type ElementType int

// Element types a Parameter may hold.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Sentinel errors returned by Service. The misspelled identifiers are kept
// as-is because they are part of the exported API.
var (
	// ErrUnintialized is returned by SendGrads when FinishInitParams has
	// not been called yet.
	ErrUnintialized = errors.New("pserver not initialized")

	// ErrAlreadyIntialized is returned by BeginInitParams, InitParam and
	// FinishInitParams once FinishInitParams has already completed.
	ErrAlreadyIntialized = errors.New("pserver already initialized")
)

// Parameter is a piece of data to sync with the parameter server.
//
// The server stores parameters keyed by Name, so a later parameter with the
// same Name replaces an earlier one.
type Parameter struct {
	// Name identifies the parameter on the server.
	Name        string
	// ElementType is the type of the values held in Content.
	ElementType ElementType
	// Content is the raw parameter data. NOTE(review): the byte encoding
	// (e.g. endianness) is not established in this file — confirm with the
	// optimizer implementation.
	Content     []byte
}

// ParameterWithConfig contains a parameter together with its configuration.
// It is the argument of InitParam.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Protocol Buffers format; not yet parsed by InitParam
}

// Gradient is the gradient of a parameter. It has the same fields as
// Parameter; its Name selects the stored parameter it is applied to.
type Gradient Parameter

// Service holds the parameter server state: the registered parameters and
// the optimizer used to apply gradients to them. Its exported methods use
// the (args, *reply) error shape; the reply arguments are unused.
//
// Create a Service with NewService. The zero value has a nil map and a nil
// channel and is not usable.
type Service struct {
	// initialized is closed by FinishInitParams to signal that parameter
	// initialization has finished.
	initialized chan struct{}

	// mu guards opt and paramMap.
	mu       sync.Mutex
	opt      *optimizer
	// paramMap holds the registered parameters, keyed by Parameter.Name.
	paramMap map[string]Parameter
}

// NewService returns a Service with an empty parameter table that is ready
// to accept initialization calls (BeginInitParams, InitParam,
// FinishInitParams).
func NewService() *Service {
	return &Service{
		initialized: make(chan struct{}),
		paramMap:    make(map[string]Parameter),
	}
}

// BeginInitParams starts parameter initialization by (re)creating the
// server's optimizer. Any previously created optimizer is cleaned up first.
//
// config is currently ignored; the optimizer is always newOptimizer(sgd, 0.01).
// It returns ErrAlreadyIntialized once FinishInitParams has been called.
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
	select {
	default:
	case <-s.initialized:
		return ErrAlreadyIntialized
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	// Release resources held by an optimizer from an earlier call.
	if prev := s.opt; prev != nil {
		prev.Cleanup()
	}

	// TODO(helin): parse learning rate from config
	const defaultLearningRate = 0.01
	s.opt = newOptimizer(sgd, defaultLearningRate)
	return nil
}

// InitParam registers one parameter with the server. A parameter with the
// same name that was registered earlier is replaced.
//
// The parameter configuration (paramWithConfigs.Config) is not used yet.
// It returns ErrAlreadyIntialized once FinishInitParams has been called.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	default:
	case <-s.initialized:
		return ErrAlreadyIntialized
	}

	// TODO(helin): parse parameter config

	param := paramWithConfigs.Param

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if param.Content is properly memory aligned; if
	// not, copy it to a memory aligned region.
	s.paramMap[param.Name] = param
	return nil
}

// FinishInitParams marks parameter initialization as complete. Afterwards
// SendGrads is accepted, blocked GetParams/SaveModel calls proceed, and
// BeginInitParams, InitParam and FinishInitParams return
// ErrAlreadyIntialized.
//
// The check and the close run under s.mu. Without the lock, two concurrent
// calls could both observe the channel as open and both close it, and
// closing an already-closed channel panics.
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	select {
	case <-s.initialized:
		return ErrAlreadyIntialized
	default:
	}

	close(s.initialized)
	return nil
}

func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
	select {
	case <-s.initialized:
	default:
		return ErrUnintialized
	}

112 113 114 115 116
	count := len(grads)
	if count == 0 {
		return nil
	}

117
	s.mu.Lock()
H
Helin Wang 已提交
118
	defer s.mu.Unlock()
119 120 121 122 123 124 125

	for _, g := range grads {
		if _, ok := s.paramMap[g.Name]; !ok {
			return fmt.Errorf("parameter: %s does not exist", g.Name)
		}
	}

126
	errCh := make(chan error, count)
127 128
	for _, g := range grads {
		go func(p Parameter, g Gradient) {
129 130
			err := s.opt.UpdateParameter(p, g)
			errCh <- err
131 132 133
		}(s.paramMap[g.Name], g)
	}

134 135 136 137 138 139 140 141 142 143 144
	recv := 0
	for err := range errCh {
		if err != nil {
			return err
		}

		recv++
		if recv == count {
			break
		}
	}
145 146 147 148 149 150
	return nil
}

func (s *Service) GetParams(names []string, parameters *[]Parameter) error {
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
151
	defer s.mu.Unlock()
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179

	for _, n := range names {
		if _, ok := s.paramMap[n]; !ok {
			return fmt.Errorf("parameter: %s does not exist", n)
		}
	}

	*parameters = make([]Parameter, len(names))
	for i, n := range names {
		// The parameter content (a byte slice) may change
		// during RPC serialization due to write from other
		// goroutine, we allow it since mini-batch based deep
		// learning optimization methods are stochastic in
		// nature. This race condition is allowed deliberately
		// to save the program from making a copy of the
		// paramter content.
		(*parameters)[i] = s.paramMap[n]
	}

	return nil
}

// SaveModel is intended to save the model to path. It blocks until
// FinishInitParams has been called.
//
// NOTE(review): not implemented yet — it currently saves nothing and always
// returns nil.
func (s *Service) SaveModel(path string, dummy *int) error {
	// Wait until parameter initialization has finished.
	<-s.initialized

	// TODO
	return nil
}