service.go 3.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
package pserver

import (
	"errors"
	"fmt"
	"sync"
)

// ElementType enumerates the kinds of elements a Parameter can hold;
// the supported values are declared as constants of this type.
type ElementType int

var (
	// ErrAlreadyInitialized is returned by InitParam and
	// FinishInitParams once parameter initialization has finished.
	ErrAlreadyInitialized = errors.New("pserver already initialized")

	// ErrUninitialized is returned by SendGrad when it is called
	// before parameter initialization has finished.
	ErrUninitialized = errors.New("pserver not fully initialized")
)

// Supported element types.
//
// The values are spelled out explicitly (the same 0..5 sequence iota
// would produce). They travel inside Parameter over RPC, so the numbering
// is part of the wire data and must not shift when constants are added
// or reordered.
const (
	Int32   ElementType = 0
	UInt32  ElementType = 1
	Int64   ElementType = 2
	UInt64  ElementType = 3
	Float32 ElementType = 4
	Float64 ElementType = 5
)

// Parameter is a piece of data to sync with the parameter server.
//
// The Service stores parameters keyed by Name and returns them by
// value from GetParam.
type Parameter struct {
	// Name identifies the parameter; the server indexes parameters by it.
	Name        string
	// ElementType describes the type of the elements stored in Content.
	ElementType ElementType
	// Content holds the raw bytes of the elements.
	// NOTE(review): the element encoding/endianness is not specified
	// here — presumably native layout of ElementType; confirm with the
	// optimizer implementation.
	Content     []byte
}

// ParameterWithConfig contains the parameter and the configuration.
// It is the argument accepted by Service.InitParam.
type ParameterWithConfig struct {
	// Param is the parameter to register with the server.
	Param  Parameter
	Config []byte // parameter configuration, serialized in Protocol Buffers format
}

// Gradient is the gradient of the parameter.
//
// It has the same fields as Parameter. Name selects the stored parameter
// the gradient is applied to, and Content holds the gradient values
// (presumably encoded the same way as Parameter.Content — confirm).
type Gradient Parameter

H
Helin Wang 已提交
41
// Service is the RPC service for pserver.
42 43 44 45 46 47 48 49
type Service struct {
	initialized chan struct{}

	mu       sync.Mutex
	opt      *optimizer
	paramMap map[string]Parameter
}

H
Helin Wang 已提交
50
// NewService creates a new service.
51
func NewService() *Service {
Q
qiaolongfei 已提交
52
	s := &Service{opt: newOptimizer(sgd, 0.005)}
53 54 55 56 57
	s.paramMap = make(map[string]Parameter)
	s.initialized = make(chan struct{})
	return s
}

H
Helin Wang 已提交
58
// InitParam initializes a parameter.
59 60 61
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
62
		return ErrAlreadyInitialized
63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
	s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
	return nil
}

H
Helin Wang 已提交
78 79
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
80 81 82
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
H
Helin Wang 已提交
83
		return ErrAlreadyInitialized
84 85 86 87 88 89 90
	default:
	}

	close(s.initialized)
	return nil
}

91
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
92
// optimization.
93
func (s *Service) SendGrad(g Gradient, dummy *int) error {
94 95 96 97 98
	select {
	case <-s.initialized:
	default:
		return ErrUninitialized
	}
99 100

	s.mu.Lock()
H
Helin Wang 已提交
101
	defer s.mu.Unlock()
102

103 104 105
	p, ok := s.paramMap[g.Name]
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", g.Name)
106 107
	}

108
	return s.opt.UpdateParameter(p, g)
109 110
}

111 112
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
113 114
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
115
	defer s.mu.Unlock()
116

117 118 119
	p, ok := s.paramMap[name]
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
120 121
	}

122 123 124 125 126 127 128 129
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
	*parameter = p
130 131 132
	return nil
}

H
Helin Wang 已提交
133 134
// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
135 136 137 138 139
	<-s.initialized

	// TODO
	return nil
}