diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index fe1fe5f6f03b315cf30d96e171dca53e32efa040..6c85b1804bb9c5f3a8bc46bb3f54cc62c56cca70 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -18,6 +18,7 @@ func main() { etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") + numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job") logLevel := flag.String("log-level", "info", "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() @@ -29,7 +30,7 @@ func main() { log.SetLevel(level) timeout := time.Second * time.Duration((*etcdTimeout)) - s, err := pserver.NewService(*etcdEndpoint, timeout) + s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout) if err != nil { panic(err) } diff --git a/go/pserver/service.go b/go/pserver/service.go index 7e2b841dd8e8efda8b3033645b4b1b6a596e53c4..f966595fdccbf23e23f94a857503ce05815164ef 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -73,7 +73,7 @@ type Service struct { // NewService creates a new service, will bypass etcd registration if no // endpoints specified. -func NewService(endpoints string, timeout time.Duration) (*Service, error) { +func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) { s := &Service{opt: newOptimizer(sgd, 0.005)} s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) @@ -103,6 +103,22 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { log.Debugf("inited client to %s", s.etcdEndpoints) break } + // init /ps_desired using transaction, for multiple pservers may want to write + // it at the same time. + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := s.initDesiredPsercers(ctx, numPservers) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(s.etcdTimeout) + continue + } + break + } + // TODO: when implementing extending or reducing pservers, /ps_desired is + // changed, then we need to watch /ps_desired node for events. For now, just + // write once when init and read from it. // wait and set s.desired init value for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -141,6 +157,16 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { return s, nil } +func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { + return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { + dsStr := c.Get(PsDesired) + if dsStr == "" { + c.Put(PsDesired, strconv.Itoa(numPservers)) + } + return nil + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) +} + // registerPserverEtcd registers pserver node on etcd using transaction. func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {