提交 09abfd91 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #2593 from typhoonzero/set_ps_desired

Set ps_desired when pserver init
...@@ -18,6 +18,7 @@ func main() { ...@@ -18,6 +18,7 @@ func main() {
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd") "comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
logLevel := flag.String("log-level", "info", logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic") "log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
...@@ -29,7 +30,7 @@ func main() { ...@@ -29,7 +30,7 @@ func main() {
log.SetLevel(level) log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout)) timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout) s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -73,7 +73,7 @@ type Service struct { ...@@ -73,7 +73,7 @@ type Service struct {
// NewService creates a new service, will bypass etcd registration if no // NewService creates a new service, will bypass etcd registration if no
// endpoints specified. // endpoints specified.
func NewService(endpoints string, timeout time.Duration) (*Service, error) { func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)} s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
...@@ -103,6 +103,22 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { ...@@ -103,6 +103,22 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
log.Debugf("inited client to %s", s.etcdEndpoints) log.Debugf("inited client to %s", s.etcdEndpoints)
break break
} }
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.initDesiredPsercers(ctx, numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value // wait and set s.desired init value
for { for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Second)
...@@ -141,6 +157,16 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { ...@@ -141,6 +157,16 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
return s, nil return s, nil
} }
func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction. // registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册