etcd_client.go 5.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
package pserver

import (
	"context"
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/clientv3/concurrency"
	log "github.com/sirupsen/logrus"
)

Q
Qiao Longfei 已提交
16 17 18
// etcd keys and key prefixes used by the pserver.
const (
	// PsDesired is the etcd key that stores the desired pserver count.
	PsDesired = "/ps_desired"
	// PsPath is the etcd key prefix under which each pserver stores its
	// address; the pserver index is appended to form the full key.
	PsPath = "/ps/"
	// PsCheckpoint is the etcd path for storing checkpoint information.
	PsCheckpoint = "/checkpoints/"
)

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
//
// NewEtcdClient only records the configuration; the etcd connection,
// the external IP and the desired pserver count are filled in by
// Register.
type EtcdClient struct {
	// numPservers is the pserver count this process proposes for
	// PsDesired. etcdEndpoints is a comma-separated list of etcd
	// endpoints. etcdClient is nil until Register has connected.
	numPservers   int
	etcdEndpoints string
	etcdClient    *clientv3.Client
	// etcdTimeout is the etcd dial timeout; it is also used as the
	// retry interval between failed attempts.
	etcdTimeout time.Duration
	// externalIP is the address advertised to etcd for this pserver.
	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
	externalIP string
	// desired is the desired number of pservers in the job, as read
	// back from PsDesired. It is assumed not to change during one
	// training job.
	desired int
}

// NewEtcdClient creates an EtcdClient holding the given configuration.
//
// endpoints is a comma-separated list of etcd endpoints, numPservers is
// the pserver count to propose, and timeout is used both as the dial
// timeout and as the retry interval. No connection is made here; call
// Register to connect.
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
	client := new(EtcdClient)
	client.etcdEndpoints = endpoints
	client.numPservers = numPservers
	client.etcdTimeout = timeout
	return client
}

// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
52
func (e *EtcdClient) Register(port int) (int, error) {
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79

	var err error
	e.externalIP, err = networkhelper.GetExternalIP()
	if err != nil {
		return 0, err
	}

	// initialize connection to etcd.
	ep := strings.Split(e.etcdEndpoints, ",")
	for {
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   ep,
			DialTimeout: e.etcdTimeout,
		})
		if err != nil {
			log.Errorf("connect to etcd error: %v", err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		e.etcdClient = cli
		log.Debugf("inited client to %s", e.etcdEndpoints)
		break
	}
	// init /ps_desired using transaction, for multiple pservers may want to write
	// it at the same time.
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
Q
Qiao Longfei 已提交
80
		_, err := e.initDesiredPservers(ctx, e.numPservers)
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
		cancel()
		if err != nil {
			log.Warn(err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		break
	}
	// TODO: when implementing extending or reducing pservers, /ps_desired is
	// changed, then we need to watch /ps_desired node for events. For now, just
	// write once when init and read from it.
	// wait and set s.desired init value
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		resp, err := e.etcdClient.Get(ctx, PsDesired)
		cancel()
		if err != nil {
			log.Errorf("getting %s error: %v", PsDesired, err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		if len(resp.Kvs) != 0 {
			e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
			if err != nil {
				log.Errorf("value of %s invalid %v\n", PsDesired, err)
				time.Sleep(e.etcdTimeout)
				// NOTE: wait util ps_desired value change
				continue
			}
			break
		}
	}

	var pserverIdx int
	// try register pserver node on etcd
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		var err error
119
		pserverIdx, err = e.registerPserverEtcd(ctx, port)
120 121 122 123 124 125 126 127 128 129 130 131
		cancel()
		if err != nil {
			log.Warn(err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		break
	}

	return pserverIdx, nil
}

Q
Qiao Longfei 已提交
132
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
133 134 135 136 137 138 139 140 141 142
	return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
		dsStr := c.Get(PsDesired)
		if dsStr == "" {
			c.Put(PsDesired, strconv.Itoa(numPservers))
		}
		return nil
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}

// registerPserverEtcd registers pserver node on etcd using transaction.
143
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
144 145 146 147
	var idx int
	_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
		registered := false
		for i := 0; i < e.desired; i++ {
Q
Qiao Longfei 已提交
148
			psKey := PsPath + strconv.Itoa(i)
149 150 151 152 153 154 155 156 157 158
			log.Debugf("checking %s", psKey)
			ps := c.Get(psKey)
			log.Debugf("got value (%s) for key: %s", ps, psKey)

			if ps == "" {
				resp, err := e.etcdClient.Grant(context.TODO(), 5)
				if err != nil {
					log.Fatal(err)
				}
				// find the first id and write info
159 160 161
				pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
				c.Put(psKey, pserverAddr, clientv3.WithLease(resp.ID))
				log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
				ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
				if kaerr != nil {
					log.Errorf("keepalive etcd node error: %v", kaerr)
					return kaerr
				}

				// Eat the keep alive message so etcd
				// will not expire the lease.
				go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
					ka := <-ch
					log.Debugf("keepalive: %d\n", ka.TTL)
				}(ch)
				log.Debug("register finished")
				idx = i
				registered = true
				break
			}
		}
180
		if registered {
181 182
			return nil
		}
183
		return errors.New("not registered, may due to already have enough pservers")
184 185 186 187 188 189 190 191
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))

	if err != nil {
		return 0, err
	}

	return idx, nil
}
D
dongzhihong 已提交
192

193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
// GetKey gets the value by the specified key.
//
// The request is bounded by timeout. When the request fails, the error
// is returned together with an empty, non-nil slice. When the key has
// no entry, an empty, non-nil slice and a nil error are returned.
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	resp, err := e.etcdClient.Get(ctx, key)
	switch {
	case err != nil:
		return []byte{}, err
	case len(resp.Kvs) == 0:
		return []byte{}, nil
	}
	return resp.Kvs[0].Value, nil
}

D
dongzhihong 已提交
209
// PutKey put into etcd with value by key specified
210 211
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
D
dongzhihong 已提交
212
	_, err := e.etcdClient.Put(ctx, key, string(value))
D
dongzhihong 已提交
213
	cancel()
214
	return err
D
dongzhihong 已提交
215
}