提交 09abfd91 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #2593 from typhoonzero/set_ps_desired

Set ps_desired when pserver init
......@@ -18,6 +18,7 @@ func main() {
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
numPservers := flag.Int("num-pservers", 1, "total pserver count in a training job")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse()
......@@ -29,7 +30,7 @@ func main() {
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout)
s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout)
if err != nil {
panic(err)
}
......
......@@ -73,7 +73,7 @@ type Service struct {
// NewService creates a new service, will bypass etcd registration if no
// endpoints specified.
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{})
......@@ -103,6 +103,22 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// init /ps_desired using transaction, for multiple pservers may want to write
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.initDesiredPsercers(ctx, numPservers)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
// TODO: when implementing extending or reducing pservers, /ps_desired is
// changed, then we need to watch /ps_desired node for events. For now, just
// write once when init and read from it.
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
......@@ -141,6 +157,16 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) {
return s, nil
}
func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
c.Put(PsDesired, strconv.Itoa(numPservers))
}
return nil
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册