// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // http://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pserver import ( "context" "errors" "strconv" "strings" "time" "github.com/PaddlePaddle/Paddle/go/utils/networkhelper" "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3/concurrency" log "github.com/sirupsen/logrus" ) const ( // PsDesired is etcd path for store desired pserver count PsDesired = "/ps_desired" // PsPath is the base dir for pserver to store their addr PsPath = "/ps/" // PsCheckpoint is the etcd path for store checkpoints information PsCheckpoint = "/checkpoints/" retryTimeout = 5 * time.Second ) // EtcdClient is the etcd client that the pserver uses for fault // tolerance, service registry and coordination. type EtcdClient struct { numPservers int endpoints string client *clientv3.Client sess *concurrency.Session dialTimeout time.Duration ttlSec int // FIXME: ensure GetExternalIP gets the correct ip for trainers to connect. externalIP string // desired number of pservers in the job. // assume desired will not change during one training job. desired int } // NewEtcdClient creates an EtcdClient func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient { return &EtcdClient{ dialTimeout: dialtimeout, ttlSec: ttlSec, numPservers: numPservers, endpoints: endpoints, } } // Register registers the pserver on etcd // // Register returns the index of the current pserver. func (e *EtcdClient) Register(port int) (int, error) { var err error e.externalIP, err = networkhelper.GetExternalIP() if err != nil { return 0, err } // initialize connection to etcd. ep := strings.Split(e.endpoints, ",") for { cli, err := clientv3.New(clientv3.Config{ Endpoints: ep, DialTimeout: e.dialTimeout, }) if err != nil { log.Errorf("connect to etcd error: %v", err) time.Sleep(retryTimeout) continue } e.client = cli sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec)) if err != nil { log.Errorf("create etcd session error: %v", err) time.Sleep(retryTimeout) continue } e.sess = sess log.Debugf("inited client to %s", e.endpoints) break } // init /ps_desired using transaction, for multiple pservers may want to write // it at the same time. for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) _, err := e.initDesiredPservers(ctx, e.numPservers) cancel() if err != nil { log.Warn(err) time.Sleep(retryTimeout) continue } break } // TODO: when implementing extending or reducing pservers, /ps_desired is // changed, then we need to watch /ps_desired node for events. For now, just // write once when init and read from it. // wait and set s.desired init value for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) resp, err := e.client.Get(ctx, PsDesired) cancel() if err != nil { log.Errorf("getting %s error: %v", PsDesired, err) time.Sleep(retryTimeout) continue } if len(resp.Kvs) != 0 { e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value)) if err != nil { log.Errorf("value of %s invalid %v\n", PsDesired, err) time.Sleep(retryTimeout) // NOTE: wait util ps_desired value change continue } break } } var pserverIdx int // try register pserver node on etcd for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) var err error pserverIdx, err = e.registerPserverEtcd(ctx, port) cancel() if err != nil { log.Warn(err) time.Sleep(retryTimeout) continue } break } return pserverIdx, nil } func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { return concurrency.NewSTM(e.client, func(c concurrency.STM) error { dsStr := c.Get(PsDesired) if dsStr == "" { c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease())) } return nil }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) } // registerPserverEtcd registers pserver node on etcd using transaction. func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) { var idx int _, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error { registered := false for i := 0; i < e.desired; i++ { psKey := PsPath + strconv.Itoa(i) log.Debugf("checking %s", psKey) ps := c.Get(psKey) log.Debugf("got value (%s) for key: %s", ps, psKey) if ps == "" { // find the first id and write info pserverAddr := e.externalIP + ":" + strconv.Itoa(port) c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease())) log.Debugf("set pserver node %s with value %s", psKey, pserverAddr) log.Debug("register finished") idx = i registered = true break } } if registered { return nil } return errors.New("not registered, may due to already have enough pservers") }, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads)) if err != nil { return 0, err } return idx, nil } // GetKey gets the value by the specified key func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) { ctx, cancel := context.WithTimeout(context.Background(), timeout) resp, err := e.client.Get(ctx, key) cancel() if err != nil { return []byte{}, err } kvs := resp.Kvs if len(kvs) == 0 { return []byte{}, nil } v := kvs[0].Value return v, nil } // PutKey put into etcd with value by key specified func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error { ctx, cancel := context.WithTimeout(context.Background(), timeout) _, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease())) cancel() return err } // Shutdown shuts down the etcd client gracefully. func (e *EtcdClient) Shutdown() error { var err error if e.sess != nil { err = e.sess.Close() } if e.client != nil { newErr := e.client.Close() if newErr != nil { if err != nil { log.Errorln(newErr) } else { err = newErr } } } return err }