From 0b936e9399f2a5f01f6fde1d1b78b56306a8f9ac Mon Sep 17 00:00:00 2001 From: wuyi05 Date: Thu, 22 Jun 2017 15:00:39 +0800 Subject: [PATCH] update pserver etcd --- go/cmd/pserver/pserver.go | 3 +- go/pserver/service.go | 75 ++++++++++++--------- go/utils/{ => networkhelper}/helper.go | 2 +- go/utils/{ => networkhelper}/helper_test.go | 2 +- 4 files changed, 47 insertions(+), 35 deletions(-) rename go/utils/{ => networkhelper}/helper.go (97%) rename go/utils/{ => networkhelper}/helper_test.go (87%) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index f42c90c6c6d..fe1fe5f6f03 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -18,7 +18,8 @@ func main() { etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") - logLevel := flag.String("log-level", "info", "log level, one of debug") + logLevel := flag.String("log-level", "info", + "log level, possible values: debug, info, warning, error, fatal, panic") flag.Parse() level, err := log.ParseLevel(*logLevel) diff --git a/go/pserver/service.go b/go/pserver/service.go index a5c76857abe..7400b488325 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -9,7 +9,7 @@ import ( "sync" "time" - "github.com/PaddlePaddle/Paddle/go/utils" + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" log "github.com/sirupsen/logrus" @@ -33,6 +33,9 @@ const ( Float64 ) +// PsDesired is etcd path for store desired pserver count +const PsDesired = "/ps_desired" + // Parameter is a piece of data to sync with the parameter server. type Parameter struct { Name string @@ -68,7 +71,8 @@ type Service struct { externalIP string } -// NewService creates a new service. +// NewService creates a new service, will bypass etcd registration if no +// endpoints specified. func NewService(endpoints string, timeout time.Duration) (*Service, error) { s := &Service{opt: newOptimizer(sgd, 0.005)} s.paramMap = make(map[string]Parameter) @@ -77,7 +81,7 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { s.etcdTimeout = timeout var err error - s.externalIP, err = utils.GetExternalIP() + s.externalIP, err = networkhelper.GetExternalIP() if err != nil { return nil, err } @@ -102,67 +106,74 @@ func NewService(endpoints string, timeout time.Duration) (*Service, error) { // wait and set s.desired init value for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - resp, err := s.etcdClient.Get(ctx, "/ps_desired") + resp, err := s.etcdClient.Get(ctx, PsDesired) cancel() if err != nil { - log.Errorf("getting /ps_desired error: %v", err) + log.Errorf("getting %s error: %v", PsDesired, err) time.Sleep(s.etcdTimeout) continue } - for _, ev := range resp.Kvs { - log.Debugf("key: %s, value: %s", ev.Key, ev.Value) - if string(ev.Key) == "/ps_desired" { - s.desired, err = strconv.Atoi(string(ev.Value)) - if err != nil { - log.Errorf("value of /ps_desired invalid %v\n", err) - time.Sleep(s.etcdTimeout) - // NOTE: wait util ps_desired value change - continue - } + if len(resp.Kvs) != 0 { + s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("value of %s invalid %v\n", PsDesired, err) + time.Sleep(s.etcdTimeout) + // NOTE: wait util ps_desired value change + continue } + break + } + } + // try register pserver node on etcd + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := s.registerPserverEtcd(ctx) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(s.etcdTimeout) + continue } break } - s.registerPserverEtcd() } // if endpoints != "" // Bypass etcd registration if no endpoints specified return s, nil } // registerPserverEtcd registers pserver node on etcd using transaction. -func (s *Service) registerPserverEtcd() (*clientv3.TxnResponse, error) { - return concurrency.NewSTMRepeatable(context.TODO(), s.etcdClient, func(c concurrency.STM) error { +func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { + return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { + registered := false for i := 0; i < s.desired; i++ { psKey := "/ps/" + strconv.Itoa(i) log.Debugf("checking %s", psKey) ps := c.Get(psKey) log.Debugf("got value (%s) for key: %s", ps, psKey) - resp, err := s.etcdClient.Grant(context.TODO(), 5) - if err != nil { - log.Fatal(err) - } - if ps == "" { + resp, err := s.etcdClient.Grant(context.TODO(), 5) + if err != nil { + log.Fatal(err) + } // find the first id and write info c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) - ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) + _, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) if kaerr != nil { log.Errorf("keepalive etcd node error: %v", kaerr) return kaerr } - // FIXME: does this really needed? - go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { - ka := <-ch - log.Debugf("keepalive: %d\n", ka.TTL) - }(ch) + log.Debug("register finished") + registered = true break } } - log.Debug("register finished") - return nil - }) + if registered == true { + return nil + } + return errors.New("not registerd, may due to already have enough pservers") + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) } // InitParam initializes a parameter. diff --git a/go/utils/helper.go b/go/utils/networkhelper/helper.go similarity index 97% rename from go/utils/helper.go rename to go/utils/networkhelper/helper.go index 3220fd6c78b..fbeaea8f5e7 100644 --- a/go/utils/helper.go +++ b/go/utils/networkhelper/helper.go @@ -1,4 +1,4 @@ -package utils +package networkhelper import ( "errors" diff --git a/go/utils/helper_test.go b/go/utils/networkhelper/helper_test.go similarity index 87% rename from go/utils/helper_test.go rename to go/utils/networkhelper/helper_test.go index aa7c509768e..4208f9e358f 100644 --- a/go/utils/helper_test.go +++ b/go/utils/networkhelper/helper_test.go @@ -1,4 +1,4 @@ -package utils +package networkhelper import "testing" -- GitLab