From 4cc9680cc60296f6071fa34893fda4f3d6806b97 Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Wed, 28 Jun 2017 01:16:28 +0000 Subject: [PATCH] Make pserver able to get server index without etcd (decouple pserver with etcd) The pserver need to have server index for saving model on the distributed file system. The server index comes from etcd if etcd is used, or user can manually specify them. So we need pserver.NewService() to take index as an argument. Since index could come from etcd, it would be strange if pserver takes an index as argument, at the same time get the index from etcd. so we will need to decouple pserver with etcd. --- go/cmd/pserver/pserver.go | 8 +- go/master/etcd_client.go | 4 +- go/pserver/client_test.go | 3 +- go/pserver/etcd_client.go | 181 +++++++++++++++++++++++++++++++++++++ go/pserver/service.go | 156 ++------------------------------ go/pserver/service_test.go | 8 +- 6 files changed, 201 insertions(+), 159 deletions(-) create mode 100644 go/pserver/etcd_client.go diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 6c85b1804bb..8a42d4f8af1 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -30,7 +30,13 @@ func main() { log.SetLevel(level) timeout := time.Second * time.Duration((*etcdTimeout)) - s, err := pserver.NewService(*etcdEndpoint, *numPservers, timeout) + e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + idx, err := e.Register() + if err != nil { + panic(err) + } + + s, err := pserver.NewService(idx) if err != nil { panic(err) } diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go index b7293a75989..f7b46385773 100644 --- a/go/master/etcd_client.go +++ b/go/master/etcd_client.go @@ -18,8 +18,8 @@ const ( DefaultAddrPath = "/master/addr" ) -// EtcdClient is the etcd client that master uses for fault tolerance -// and service registry. +// EtcdClient is the etcd client that the master uses for fault +// tolerance and service registry. type EtcdClient struct { lockPath string statePath string diff --git a/go/pserver/client_test.go b/go/pserver/client_test.go index 4a62ae88a4a..5bd16118a7f 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client_test.go @@ -7,7 +7,6 @@ import ( "strconv" "strings" "testing" - "time" "github.com/PaddlePaddle/Paddle/go/pserver" ) @@ -31,7 +30,7 @@ func init() { port[i] = p go func(l net.Listener) { - s, err := pserver.NewService("", 1, time.Second*5) + s, err := pserver.NewService(0) if err != nil { panic(err) } diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go new file mode 100644 index 00000000000..4d88243edd4 --- /dev/null +++ b/go/pserver/etcd_client.go @@ -0,0 +1,181 @@ +package pserver + +import ( + "context" + "errors" + "strconv" + "strings" + "time" + + "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" + "github.com/coreos/etcd/clientv3" + "github.com/coreos/etcd/clientv3/concurrency" + log "github.com/sirupsen/logrus" +) + +// EtcdClient is the etcd client that the pserver uses for fault +// tolerance, service registry and coordination. +type EtcdClient struct { + numPservers int + etcdEndpoints string + etcdClient *clientv3.Client + // etcdTimeout is also used as retry intervals. + etcdTimeout time.Duration + // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. + externalIP string + // desired number of pservers in the job. + // assume desired will not change during one training job. + desired int +} + +// NewEtcdClient creates an EtcdClient +func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient { + return &EtcdClient{ + etcdTimeout: timeout, + numPservers: numPservers, + etcdEndpoints: endpoints, + } +} + +// Register registers the pserver on etcd +// +// Register returns the index of the current pserver. +func (e *EtcdClient) Register() (int, error) { + + var err error + e.externalIP, err = networkhelper.GetExternalIP() + if err != nil { + return 0, err + } + + // initialize connection to etcd. + ep := strings.Split(e.etcdEndpoints, ",") + for { + cli, err := clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: e.etcdTimeout, + }) + if err != nil { + log.Errorf("connect to etcd error: %v", err) + time.Sleep(e.etcdTimeout) + continue + } + e.etcdClient = cli + log.Debugf("inited client to %s", e.etcdEndpoints) + break + } + // init /ps_desired using transaction, for multiple pservers may want to write + // it at the same time. + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + _, err := e.initDesiredPsercers(ctx, e.numPservers) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(e.etcdTimeout) + continue + } + break + } + // TODO: when implementing extending or reducing pservers, /ps_desired is + // changed, then we need to watch /ps_desired node for events. For now, just + // write once when init and read from it. + // wait and set s.desired init value + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + resp, err := e.etcdClient.Get(ctx, PsDesired) + cancel() + if err != nil { + log.Errorf("getting %s error: %v", PsDesired, err) + time.Sleep(e.etcdTimeout) + continue + } + if len(resp.Kvs) != 0 { + e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("value of %s invalid %v\n", PsDesired, err) + time.Sleep(e.etcdTimeout) + // NOTE: wait util ps_desired value change + continue + } + break + } + } + + var pserverIdx int + // try register pserver node on etcd + for { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + var err error + pserverIdx, err = e.registerPserverEtcd(ctx) + cancel() + if err != nil { + log.Warn(err) + time.Sleep(e.etcdTimeout) + continue + } + break + } + + return pserverIdx, nil +} + +func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { + return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + dsStr := c.Get(PsDesired) + if dsStr == "" { + c.Put(PsDesired, strconv.Itoa(numPservers)) + } + return nil + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) +} + +// registerPserverEtcd registers pserver node on etcd using transaction. +func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { + var idx int + _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { + registered := false + for i := 0; i < e.desired; i++ { + psKey := "/ps/" + strconv.Itoa(i) + log.Debugf("checking %s", psKey) + ps := c.Get(psKey) + log.Debugf("got value (%s) for key: %s", ps, psKey) + + if ps == "" { + resp, err := e.etcdClient.Grant(context.TODO(), 5) + if err != nil { + log.Fatal(err) + } + // find the first id and write info + c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID)) + log.Debugf("set pserver node %s with value %s", psKey, e.externalIP) + ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID) + if kaerr != nil { + log.Errorf("keepalive etcd node error: %v", kaerr) + return kaerr + } + + // Eat the keep alive message so etcd + // will not expire the lease. + go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { + ka := <-ch + log.Debugf("keepalive: %d\n", ka.TTL) + }(ch) + log.Debug("register finished") + idx = i + registered = true + break + } + } + if registered == true { + return nil + } + return errors.New("not registerd, may due to already have enough pservers") + }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) + + if err != nil { + return 0, err + } + + return idx, nil +} diff --git a/go/pserver/service.go b/go/pserver/service.go index f966595fdcc..f386ebea1eb 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -1,18 +1,9 @@ package pserver import ( - "context" "errors" "fmt" - "strconv" - "strings" "sync" - "time" - - "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" - "github.com/coreos/etcd/clientv3" - "github.com/coreos/etcd/clientv3/concurrency" - log "github.com/sirupsen/logrus" ) // ElementType is the type of elements of a Parameter. @@ -55,160 +46,25 @@ type Gradient Parameter // Service is the RPC service for pserver. type Service struct { initialized chan struct{} + idx int mu sync.Mutex opt *optimizer paramMap map[string]Parameter - - etcdEndpoints string - etcdClient *clientv3.Client - // etcdTimeout is also used as retry intervals. - etcdTimeout time.Duration - // desired number of pservers in the job. - // assume desired will not change during one training job. - desired int - // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. - externalIP string } // NewService creates a new service, will bypass etcd registration if no // endpoints specified. -func NewService(endpoints string, numPservers int, timeout time.Duration) (*Service, error) { - s := &Service{opt: newOptimizer(sgd, 0.005)} +func NewService(idx int) (*Service, error) { + s := &Service{ + idx: idx, + opt: newOptimizer(sgd, 0.005), + } s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) - s.etcdEndpoints = endpoints - s.etcdTimeout = timeout - - var err error - s.externalIP, err = networkhelper.GetExternalIP() - if err != nil { - return nil, err - } - - if endpoints != "" { - // initialize connection to etcd, try - ep := strings.Split(s.etcdEndpoints, ",") - for { - cli, err := clientv3.New(clientv3.Config{ - Endpoints: ep, - DialTimeout: s.etcdTimeout, - }) - if err != nil { - log.Errorf("connect to etcd error: %v", err) - time.Sleep(s.etcdTimeout) - continue - } - s.etcdClient = cli - log.Debugf("inited client to %s", s.etcdEndpoints) - break - } - // init /ps_desired using transaction, for multiple pservers may want to write - // it at the same time. - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := s.initDesiredPsercers(ctx, numPservers) - cancel() - if err != nil { - log.Warn(err) - time.Sleep(s.etcdTimeout) - continue - } - break - } - // TODO: when implementing extending or reducing pservers, /ps_desired is - // changed, then we need to watch /ps_desired node for events. For now, just - // write once when init and read from it. - // wait and set s.desired init value - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - resp, err := s.etcdClient.Get(ctx, PsDesired) - cancel() - if err != nil { - log.Errorf("getting %s error: %v", PsDesired, err) - time.Sleep(s.etcdTimeout) - continue - } - if len(resp.Kvs) != 0 { - s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) - if err != nil { - log.Errorf("value of %s invalid %v\n", PsDesired, err) - time.Sleep(s.etcdTimeout) - // NOTE: wait util ps_desired value change - continue - } - break - } - } - // try register pserver node on etcd - for { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := s.registerPserverEtcd(ctx) - cancel() - if err != nil { - log.Warn(err) - time.Sleep(s.etcdTimeout) - continue - } - break - } - } // if endpoints != "" - // Bypass etcd registration if no endpoints specified return s, nil } -func (s *Service) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { - dsStr := c.Get(PsDesired) - if dsStr == "" { - c.Put(PsDesired, strconv.Itoa(numPservers)) - } - return nil - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) -} - -// registerPserverEtcd registers pserver node on etcd using transaction. -func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) { - return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error { - registered := false - for i := 0; i < s.desired; i++ { - psKey := "/ps/" + strconv.Itoa(i) - log.Debugf("checking %s", psKey) - ps := c.Get(psKey) - log.Debugf("got value (%s) for key: %s", ps, psKey) - - if ps == "" { - resp, err := s.etcdClient.Grant(context.TODO(), 5) - if err != nil { - log.Fatal(err) - } - // find the first id and write info - c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID)) - log.Debugf("set pserver node %s with value %s", psKey, s.externalIP) - ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID) - if kaerr != nil { - log.Errorf("keepalive etcd node error: %v", kaerr) - return kaerr - } - - // Eat the keep alive message so etcd - // will not expire the lease. - go func(ch <-chan *clientv3.LeaseKeepAliveResponse) { - ka := <-ch - log.Debugf("keepalive: %d\n", ka.TTL) - }(ch) - log.Debug("register finished") - registered = true - break - } - } - if registered == true { - return nil - } - return errors.New("not registerd, may due to already have enough pservers") - }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) -} - // InitParam initializes a parameter. func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { select { diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 1d84f15d78a..d9d887cffd4 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,7 +10,7 @@ import ( ) func TestFull(t *testing.T) { - s, err := pserver.NewService("", 1, time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } @@ -75,7 +75,7 @@ func TestFull(t *testing.T) { } func TestMultipleInit(t *testing.T) { - s, err := pserver.NewService("", 1, time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } @@ -91,7 +91,7 @@ func TestMultipleInit(t *testing.T) { } func TestUninitialized(t *testing.T) { - s, err := pserver.NewService("", 1, time.Second*5) + s, err := pserver.NewService(0) err = s.SendGrad(pserver.Gradient{}, nil) if err.Error() != pserver.Uninitialized { t.FailNow() @@ -99,7 +99,7 @@ func TestUninitialized(t *testing.T) { } func TestBlockUntilInitialized(t *testing.T) { - s, err := pserver.NewService("", 1, time.Second*5) + s, err := pserver.NewService(0) if err != nil { t.Error(err) } -- GitLab