service.go 6.1 KB
Newer Older
1 2 3
package pserver

import (
W
wuyi05 已提交
4
	"context"
5 6
	"errors"
	"fmt"
W
wuyi05 已提交
7 8
	"strconv"
	"strings"
9
	"sync"
W
wuyi05 已提交
10 11 12 13 14 15
	"time"

	"github.com/PaddlePaddle/Paddle/go/utils"
	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/clientv3/concurrency"
	log "github.com/sirupsen/logrus"
16 17 18 19 20
)

// ElementType is the type of elements of a Parameter.
// It is an integer enumeration; the supported values are the
// ElementType constants declared in this file (Int32 through Float64).
type ElementType int

21 22 23 24
// Error messages returned (wrapped with errors.New) by the RPC methods
// of Service.
//
// AlreadyInitialized is returned by InitParam and FinishInitParams
// when they are called after initialization has finished.
// Uninitialized is returned by SendGrad when it is called before
// initialization has finished.
const (
	AlreadyInitialized = "pserver already initialized"
	Uninitialized      = "pserver not fully initialized"
)
25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51

// Supported element types.
//
// The values are assigned by iota in declaration order, starting
// with Int32 == 0 and ending with Float64 == 5.
// NOTE(review): peers presumably rely on these numeric values, so
// confirm before reordering or inserting entries.
const (
	Int32 ElementType = iota
	UInt32
	Int64
	UInt64
	Float32
	Float64
)

// Parameter is a piece of data to sync with the parameter server.
//
// Name identifies the parameter; the Service stores and looks up
// parameters by this name. ElementType describes the element type
// of Content, and Content holds the raw bytes of the parameter.
// NOTE(review): the byte layout of Content is not visible here;
// presumably it is a packed array of ElementType values - confirm.
type Parameter struct {
	Name        string
	ElementType ElementType
	Content     []byte
}

// ParameterWithConfig contains the parameter and the configuration.
// It is the argument type of Service.InitParam, which currently
// stores only Param and does not yet parse Config.
type ParameterWithConfig struct {
	Param  Parameter
	Config []byte // parameter configuration in Proto Buffer format
}

// Gradient is the gradient of the parameter.
// It shares the field layout of Parameter; Service.SendGrad uses
// its Name to find the parameter the gradient applies to.
type Gradient Parameter

H
Helin Wang 已提交
52
// Service is the RPC service for pserver.
53 54 55 56 57 58
type Service struct {
	initialized chan struct{}

	mu       sync.Mutex
	opt      *optimizer
	paramMap map[string]Parameter
W
wuyi05 已提交
59 60 61 62 63 64 65 66 67 68

	etcdEndpoints string
	etcdClient    *clientv3.Client
	// etcdTimeout is also used as retry intervals.
	etcdTimeout time.Duration
	// desired number of pservers in the job.
	// assume desired will not change during one training job.
	desired int
	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
	externalIP string
69 70
}

H
Helin Wang 已提交
71
// NewService creates a new service.
W
wuyi05 已提交
72
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
Q
qiaolongfei 已提交
73
	s := &Service{opt: newOptimizer(sgd, 0.005)}
74 75
	s.paramMap = make(map[string]Parameter)
	s.initialized = make(chan struct{})
W
wuyi05 已提交
76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165
	s.etcdEndpoints = endpoints
	s.etcdTimeout = timeout

	var err error
	s.externalIP, err = utils.GetExternalIP()
	if err != nil {
		return nil, err
	}

	if endpoints != "" {
		// initialize connection to etcd, try
		ep := strings.Split(s.etcdEndpoints, ",")
		for {
			cli, err := clientv3.New(clientv3.Config{
				Endpoints:   ep,
				DialTimeout: s.etcdTimeout,
			})
			if err != nil {
				log.Errorf("connect to etcd error: %v", err)
				time.Sleep(s.etcdTimeout)
				continue
			}
			s.etcdClient = cli
			log.Debugf("inited client to %s", s.etcdEndpoints)
			break
		}
		// wait and set s.desired init value
		for {
			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
			resp, err := s.etcdClient.Get(ctx, "/ps_desired")
			cancel()
			if err != nil {
				log.Errorf("getting /ps_desired error: %v", err)
				time.Sleep(s.etcdTimeout)
				continue
			}
			for _, ev := range resp.Kvs {
				log.Debugf("key: %s, value: %s", ev.Key, ev.Value)
				if string(ev.Key) == "/ps_desired" {
					s.desired, err = strconv.Atoi(string(ev.Value))
					if err != nil {
						log.Errorf("value of /ps_desired invalid %v\n", err)
						time.Sleep(s.etcdTimeout)
						// NOTE: wait util ps_desired value change
						continue
					}
				}
			}
			break
		}
		s.registerPserverEtcd()
	} // if endpoints != ""
	// Bypass etcd registration if no endpoints specified
	return s, nil
}

// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd() (*clientv3.TxnResponse, error) {
	return concurrency.NewSTMRepeatable(context.TODO(), s.etcdClient, func(c concurrency.STM) error {
		for i := 0; i < s.desired; i++ {
			psKey := "/ps/" + strconv.Itoa(i)
			log.Debugf("checking %s", psKey)
			ps := c.Get(psKey)
			log.Debugf("got value (%s) for key: %s", ps, psKey)

			resp, err := s.etcdClient.Grant(context.TODO(), 5)
			if err != nil {
				log.Fatal(err)
			}

			if ps == "" {
				// find the first id and write info
				c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
				log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
				ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
				if kaerr != nil {
					log.Errorf("keepalive etcd node error: %v", kaerr)
					return kaerr
				}
				// FIXME: does this really needed?
				go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
					ka := <-ch
					log.Debugf("keepalive: %d\n", ka.TTL)
				}(ch)
				break
			}
		}
		log.Debug("register finished")
		return nil
	})
166 167
}

H
Helin Wang 已提交
168
// InitParam initializes a parameter.
169 170 171
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
	select {
	case <-s.initialized:
172
		return errors.New(AlreadyInitialized)
173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
	default:
	}

	// TODO(helin): parse parameter config

	s.mu.Lock()
	defer s.mu.Unlock()

	// TODO(helin): check if paramWithConfigs.Param.Content is
	// properly memory aligned, if not, make copy to a memory
	// aligned region.
	s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param
	return nil
}

H
Helin Wang 已提交
188 189
// FinishInitParams tells the parameter server that the parameter
// initialization has finished.
190 191 192
func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
	select {
	case <-s.initialized:
193
		return errors.New(AlreadyInitialized)
194 195 196 197 198 199 200
	default:
	}

	close(s.initialized)
	return nil
}

201
// SendGrad sends gradient to parameter servers for parameter
H
Helin Wang 已提交
202
// optimization.
203
func (s *Service) SendGrad(g Gradient, dummy *int) error {
204 205 206
	select {
	case <-s.initialized:
	default:
207
		return errors.New(Uninitialized)
208
	}
209 210

	s.mu.Lock()
H
Helin Wang 已提交
211
	defer s.mu.Unlock()
212

213 214 215
	p, ok := s.paramMap[g.Name]
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", g.Name)
216 217
	}

218
	return s.opt.UpdateParameter(p, g)
219 220
}

221 222
// GetParam gets parameters from the parameter server.
func (s *Service) GetParam(name string, parameter *Parameter) error {
223 224
	<-s.initialized
	s.mu.Lock()
H
Helin Wang 已提交
225
	defer s.mu.Unlock()
226

227 228 229
	p, ok := s.paramMap[name]
	if !ok {
		return fmt.Errorf("parameter: %s does not exist", name)
230 231
	}

232 233 234 235 236 237 238 239
	// The parameter content (a byte slice) may change
	// during RPC serialization due to write from other
	// goroutine, we allow it since mini-batch based deep
	// learning optimization methods are stochastic in
	// nature. This race condition is allowed deliberately
	// to save the program from making a copy of the
	// paramter content.
	*parameter = p
240 241 242
	return nil
}

H
Helin Wang 已提交
243 244
// Save tells the parameter server to save parameters.
func (s *Service) Save(path string, dummy *int) error {
245 246 247 248 249
	<-s.initialized

	// TODO
	return nil
}