etcd_client.go 5.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
package pserver

import (
	"context"
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/clientv3/concurrency"
	log "github.com/sirupsen/logrus"
)

Q
Qiao Longfei 已提交
16 17 18 19 20
const (
	// PsDesired is the etcd path that stores the desired pserver count.
	PsDesired = "/ps_desired"
	// PsPath is the base etcd directory under which pservers store their
	// addresses; each pserver writes to PsPath followed by its index.
	PsPath = "/ps/"
	// PsCheckpoint is the etcd path that stores checkpoint information.
	PsCheckpoint = "/checkpoints/"
)

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
//
// Create one with NewEtcdClient and call Register before using any other
// method: Register is what populates etcdClient, externalIP and desired.
type EtcdClient struct {
	// numPservers is the pserver count supplied to NewEtcdClient; Register
	// uses it as the initial value for the PsDesired key.
	numPservers   int
	// etcdEndpoints is a comma-separated list of etcd endpoints.
	etcdEndpoints string
	// etcdClient is the underlying etcd v3 client; it is nil until
	// Register has connected.
	etcdClient    *clientv3.Client
	// etcdTimeout is the etcd dial timeout and is also used as the retry
	// interval.
	etcdTimeout time.Duration
	// externalIP is the address this pserver writes into its etcd slot.
	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
	externalIP string
	// desired is the desired number of pservers in the job, as read back
	// from the PsDesired key.
	// It is assumed not to change during one training job.
	desired int
}

// NewEtcdClient builds an EtcdClient from a comma-separated endpoint list,
// the number of pservers and the timeout used for dialing and retrying.
// No connection is made here; that happens in Register.
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
	client := new(EtcdClient)
	client.etcdEndpoints = endpoints
	client.numPservers = numPservers
	client.etcdTimeout = timeout
	return client
}

// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {

	var err error
	e.externalIP, err = networkhelper.GetExternalIP()
	if err != nil {
		return 0, err
	}

	// initialize connection to etcd.
	ep := strings.Split(e.etcdEndpoints, ",")
	for {
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   ep,
			DialTimeout: e.etcdTimeout,
		})
		if err != nil {
			log.Errorf("connect to etcd error: %v", err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		e.etcdClient = cli
		log.Debugf("inited client to %s", e.etcdEndpoints)
		break
	}
	// init /ps_desired using transaction, for multiple pservers may want to write
	// it at the same time.
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
Q
Qiao Longfei 已提交
80
		_, err := e.initDesiredPservers(ctx, e.numPservers)
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
		cancel()
		if err != nil {
			log.Warn(err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		break
	}
	// TODO: when implementing extending or reducing pservers, /ps_desired is
	// changed, then we need to watch /ps_desired node for events. For now, just
	// write once when init and read from it.
	// wait and set s.desired init value
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		resp, err := e.etcdClient.Get(ctx, PsDesired)
		cancel()
		if err != nil {
			log.Errorf("getting %s error: %v", PsDesired, err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		if len(resp.Kvs) != 0 {
			e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
			if err != nil {
				log.Errorf("value of %s invalid %v\n", PsDesired, err)
				time.Sleep(e.etcdTimeout)
				// NOTE: wait util ps_desired value change
				continue
			}
			break
		}
	}

	var pserverIdx int
	// try register pserver node on etcd
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		var err error
		pserverIdx, err = e.registerPserverEtcd(ctx)
		cancel()
		if err != nil {
			log.Warn(err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		break
	}

	return pserverIdx, nil
}

Q
Qiao Longfei 已提交
132
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
	return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
		dsStr := c.Get(PsDesired)
		if dsStr == "" {
			c.Put(PsDesired, strconv.Itoa(numPservers))
		}
		return nil
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}

// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
	var idx int
	_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
		registered := false
		for i := 0; i < e.desired; i++ {
Q
Qiao Longfei 已提交
148
			psKey := PsPath + strconv.Itoa(i)
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
			log.Debugf("checking %s", psKey)
			ps := c.Get(psKey)
			log.Debugf("got value (%s) for key: %s", ps, psKey)

			if ps == "" {
				resp, err := e.etcdClient.Grant(context.TODO(), 5)
				if err != nil {
					log.Fatal(err)
				}
				// find the first id and write info
				c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
				log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
				ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
				if kaerr != nil {
					log.Errorf("keepalive etcd node error: %v", kaerr)
					return kaerr
				}

				// Eat the keep alive message so etcd
				// will not expire the lease.
				go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
					ka := <-ch
					log.Debugf("keepalive: %d\n", ka.TTL)
				}(ch)
				log.Debug("register finished")
				idx = i
				registered = true
				break
			}
		}
		if registered == true {
			return nil
		}
		return errors.New("not registerd, may due to already have enough pservers")
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))

	if err != nil {
		return 0, err
	}

	return idx, nil
}
D
dongzhihong 已提交
191 192 193 194 195 196 197 198 199 200 201

// PutKey writes value under key in etcd.
//
// timeout is the number of seconds to wait for the put to complete. The
// value is stored as a string. PutKey returns the error reported by the etcd
// client, or nil on success. Register must have been called first so that
// the etcd connection exists.
func (e *EtcdClient) PutKey(key string, value []byte, timeout int) error {
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
	// Always release the timer behind the context, whatever Put returns.
	defer cancel()

	_, err := e.etcdClient.Put(ctx, key, string(value))
	return err
}