package pserver import ( "context" "errors" "fmt" "strconv" "strings" "sync" "time" "github.com/PaddlePaddle/Paddle/go/utils" "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. type ElementType int const ( AlreadyInitialized = "pserver already initialized" Uninitialized = "pserver not fully initialized" ) // Supported element types const ( Int32 ElementType = iota UInt32 Int64 UInt64 Float32 Float64 ) // Parameter is a piece of data to sync with the parameter server. type Parameter struct { Name string ElementType ElementType Content []byte } // ParameterWithConfig contains the parameter and the configuration. type ParameterWithConfig struct { Param Parameter Config []byte // parameter configuration in Proto Buffer format } // Gradient is the gradient of the parameter. type Gradient Parameter // Service is the RPC service for pserver. type Service struct { initialized chan struct{} mu sync.Mutex opt *optimizer paramMap map[string]Parameter etcdEndpoints string etcdClient *clientv3.Client // etcdTimeout is also used as retry intervals. etcdTimeout time.Duration // desired number of pservers in the job. // assume desired will not change during one training job. desired int // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. externalIP string } // NewService creates a new service. func NewService(endpoints string, timeout time.Duration) (*Service, error) { s := &Service{opt: newOptimizer(sgd, 0.005)} s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) s.etcdEndpoints = endpoints s.etcdTimeout = timeout var err error s.externalIP, err = utils.GetExternalIP() if err != nil { return nil, err } if endpoints != "" { // initialize connection to etcd, try ep := strings.Split(s.etcdEndpoints, ",") for { cli, err := clientv3.New(clientv3.Config{ Endpoints: ep, DialTimeout: s.etcdTimeout, }) if err != nil { log.Errorf("connect to etcd error: %v", err) time.Sleep(s.etcdTimeout) continue } s.etcdClient = cli log.Debugf("inited client to %s", s.etcdEndpoints) break } // wait and set s.desired init value for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) resp, err := s.etcdClient.Get(ctx, "/ps_desired") cancel() if err != nil { log.Errorf("getting /ps_desired error: %v", err) time.Sleep(s.etcdTimeout) continue } for _, ev := range resp.Kvs { log.Debugf("key: %s, value: %s", ev.Key, ev.Value) if string(ev.Key) == "/ps_desired" { s.desired, err = strconv.Atoi(string(ev.Value)) if err != nil { log.Errorf("value of /ps_desired invalid %v\n", err) time.Sleep(s.etcdTimeout) // NOTE: wait util ps_desired value change continue } } } break } s.registerPserverEtcd() } // if endpoints != "" // Bypass etcd registration if no endpoints specified return s, nil } // registerPserverEtcd registers pserver node on etcd using transaction. func (s *Service) registerPserverEtcd() (*clientv3.TxnResponse, error) { return concurrency.NewSTMRepeatable(context.TODO(), s.etcdClient, func(c concurrency.STM) error { for i := 0; i < s.desired; i++ { psKey := "/ps/" + strconv.Itoa(i) log.Debugf("checking %s", psKey) ps := c.Get(psKey) log.Debugf("got value (%s) for key: %s", ps, psKey) resp, err := s.etcdClient.Grant(context.TODO(), 5) if err != nil { log.Fatal(err) } if ps == "" { // find the first id and write info c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) if kaerr != nil { log.Errorf("keepalive etcd node error: %v", kaerr) return kaerr } // FIXME: does this really needed? go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { ka := <-ch log.Debugf("keepalive: %d\n", ka.TTL) }(ch) break } } log.Debug("register finished") return nil }) } // InitParam initializes a parameter. func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { select { case <-s.initialized: return errors.New(AlreadyInitialized) default: } // TODO(helin): parse parameter config s.mu.Lock() defer s.mu.Unlock() // TODO(helin): check if paramWithConfigs.Param.Content is // properly memory aligned, if not, make copy to a memory // aligned region. s.paramMap[paramWithConfigs.Param.Name] = paramWithConfigs.Param return nil } // FinishInitParams tells the parameter server that the parameter // initialization has finished. func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { select { case <-s.initialized: return errors.New(AlreadyInitialized) default: } close(s.initialized) return nil } // SendGrad sends gradient to parameter servers for parameter // optimization. func (s *Service) SendGrad(g Gradient, dummy *int) error { select { case <-s.initialized: default: return errors.New(Uninitialized) } s.mu.Lock() defer s.mu.Unlock() p, ok := s.paramMap[g.Name] if !ok { return fmt.Errorf("parameter: %s does not exist", g.Name) } return s.opt.UpdateParameter(p, g) } // GetParam gets parameters from the parameter server. func (s *Service) GetParam(name string, parameter *Parameter) error { <-s.initialized s.mu.Lock() defer s.mu.Unlock() p, ok := s.paramMap[name] if !ok { return fmt.Errorf("parameter: %s does not exist", name) } // The parameter content (a byte slice) may change // during RPC serialization due to write from other // goroutine, we allow it since mini-batch based deep // learning optimization methods are stochastic in // nature. This race condition is allowed deliberately // to save the program from making a copy of the // paramter content. *parameter = p return nil } // Save tells the parameter server to save parameters. func (s *Service) Save(path string, dummy *int) error { <-s.initialized // TODO return nil }