etcd_client.go 5.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
package pserver

import (
	"context"
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/clientv3/concurrency"
	log "github.com/sirupsen/logrus"
)

Q
Qiao Longfei 已提交
16 17 18
// etcd key layout shared by the pservers of one job.
const (
	// PsDesired is the etcd key that stores the desired pserver count.
	PsDesired = "/ps_desired"
	// PsPath is the etcd key prefix under which each pserver stores its
	// address; the pserver index is appended to it.
	PsPath = "/ps/"
	// PsCheckpoint is the etcd key prefix for checkpoint information.
	PsCheckpoint = "/checkpoints/"
)

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
//
// NewEtcdClient only fills in the configuration fields; the etcd
// connection, the external IP and the desired count are populated by
// Register.
type EtcdClient struct {
	// numPservers is the count Register writes to PsDesired when that key
	// is still empty. etcdEndpoints is the comma-separated endpoint list
	// that Register splits and dials. etcdClient is the resulting
	// connection; it is nil until Register has connected.
	numPservers   int
	etcdEndpoints string
	etcdClient    *clientv3.Client
	// etcdTimeout is the etcd dial timeout and is also used as the sleep
	// interval between retries in Register.
	etcdTimeout time.Duration
	// externalIP is the address stored as this pserver's value under PsPath.
	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
	externalIP string
	// desired is the number of pservers in the job, parsed from the value
	// of PsDesired during Register.
	// It is assumed not to change during one training job.
	desired int
}

// NewEtcdClient returns an EtcdClient configured with the comma-separated
// etcd endpoints, the pserver count to publish and the timeout used for
// dialing and retry back-off. It does not contact etcd; call Register to
// connect.
func NewEtcdClient(endpoints string, numPservers int, timeout time.Duration) *EtcdClient {
	client := new(EtcdClient)
	client.etcdEndpoints = endpoints
	client.numPservers = numPservers
	client.etcdTimeout = timeout
	return client
}

// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
func (e *EtcdClient) Register() (int, error) {

	var err error
	e.externalIP, err = networkhelper.GetExternalIP()
	if err != nil {
		return 0, err
	}

	// initialize connection to etcd.
	ep := strings.Split(e.etcdEndpoints, ",")
	for {
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   ep,
			DialTimeout: e.etcdTimeout,
		})
		if err != nil {
			log.Errorf("connect to etcd error: %v", err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		e.etcdClient = cli
		log.Debugf("inited client to %s", e.etcdEndpoints)
		break
	}
	// init /ps_desired using transaction, for multiple pservers may want to write
	// it at the same time.
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
Q
Qiao Longfei 已提交
80
		_, err := e.initDesiredPservers(ctx, e.numPservers)
81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
		cancel()
		if err != nil {
			log.Warn(err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		break
	}
	// TODO: when implementing extending or reducing pservers, /ps_desired is
	// changed, then we need to watch /ps_desired node for events. For now, just
	// write once when init and read from it.
	// wait and set s.desired init value
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		resp, err := e.etcdClient.Get(ctx, PsDesired)
		cancel()
		if err != nil {
			log.Errorf("getting %s error: %v", PsDesired, err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		if len(resp.Kvs) != 0 {
			e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
			if err != nil {
				log.Errorf("value of %s invalid %v\n", PsDesired, err)
				time.Sleep(e.etcdTimeout)
				// NOTE: wait util ps_desired value change
				continue
			}
			break
		}
	}

	var pserverIdx int
	// try register pserver node on etcd
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		var err error
		pserverIdx, err = e.registerPserverEtcd(ctx)
		cancel()
		if err != nil {
			log.Warn(err)
			time.Sleep(e.etcdTimeout)
			continue
		}
		break
	}

	return pserverIdx, nil
}

Q
Qiao Longfei 已提交
132
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
	return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
		dsStr := c.Get(PsDesired)
		if dsStr == "" {
			c.Put(PsDesired, strconv.Itoa(numPservers))
		}
		return nil
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}

// registerPserverEtcd registers pserver node on etcd using transaction.
func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
	var idx int
	_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
		registered := false
		for i := 0; i < e.desired; i++ {
Q
Qiao Longfei 已提交
148
			psKey := PsPath + strconv.Itoa(i)
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
			log.Debugf("checking %s", psKey)
			ps := c.Get(psKey)
			log.Debugf("got value (%s) for key: %s", ps, psKey)

			if ps == "" {
				resp, err := e.etcdClient.Grant(context.TODO(), 5)
				if err != nil {
					log.Fatal(err)
				}
				// find the first id and write info
				c.Put(psKey, e.externalIP, clientv3.WithLease(resp.ID))
				log.Debugf("set pserver node %s with value %s", psKey, e.externalIP)
				ch, kaerr := e.etcdClient.KeepAlive(context.TODO(), resp.ID)
				if kaerr != nil {
					log.Errorf("keepalive etcd node error: %v", kaerr)
					return kaerr
				}

				// Eat the keep alive message so etcd
				// will not expire the lease.
				go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
					ka := <-ch
					log.Debugf("keepalive: %d\n", ka.TTL)
				}(ch)
				log.Debug("register finished")
				idx = i
				registered = true
				break
			}
		}
		if registered == true {
			return nil
		}
		return errors.New("not registerd, may due to already have enough pservers")
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))

	if err != nil {
		return 0, err
	}

	return idx, nil
}
D
dongzhihong 已提交
191

192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
// GetKey fetches the value stored under key, waiting at most timeout for
// etcd to answer. It returns an empty, non-nil slice both when the request
// fails (together with the error) and when the key does not exist (with a
// nil error).
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	resp, err := e.etcdClient.Get(ctx, key)
	switch {
	case err != nil:
		return []byte{}, err
	case len(resp.Kvs) == 0:
		return []byte{}, nil
	}
	return resp.Kvs[0].Value, nil
}

D
dongzhihong 已提交
208
// PutKey put into etcd with value by key specified
209 210
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
D
dongzhihong 已提交
211
	_, err := e.etcdClient.Put(ctx, key, string(value))
D
dongzhihong 已提交
212 213 214 215 216 217
	cancel()
	if err != nil {
		return err
	}
	return nil
}