etcd_client.go 2.8 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
package client

import (
	"context"
	"strconv"
	"strings"
	"time"

	"github.com/PaddlePaddle/Paddle/go/pserver"
	"github.com/coreos/etcd/clientv3"
	log "github.com/sirupsen/logrus"
)

// DefaultEtcdTimeout is the default etcd timeout. NewEtcd uses it as the
// dial timeout, as the pause between connection attempts, and as the
// per-client request timeout.
const DefaultEtcdTimeout = 5 * time.Second

// EtcdClient is used by the pserver client, which is part of the trainer
// process.
// TODO:
// 1. add a watcher to watch state changes of the pservers
// 2. add an etcd lock
type EtcdClient struct {
	// client is the etcd v3 client used to issue requests.
	client    *clientv3.Client
	// timeout bounds each etcd request and is also the pause between retries.
	timeout   time.Duration
	// endpoints holds the etcd endpoint addresses the client was created with.
	endpoints []string
}

// Desired reads the desired number of parameter servers from etcd.
//
// It blocks until the value stored under the pserver.PsDesired key can be
// fetched and parsed as a non-negative integer, sleeping for p.timeout
// between attempts. It therefore does not return while etcd is unreachable,
// the key is missing, or the stored value is invalid.
func (p *EtcdClient) Desired() int {
	var psDesired int
	for {
		ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
		resp, err := p.client.Get(ctx, pserver.PsDesired)
		cancel()
		if err != nil {
			log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
			time.Sleep(p.timeout)
			continue
		}

		kvs := resp.Kvs
		if len(kvs) == 0 {
			log.Infoln("Waiting for ps desired registered ...")
			time.Sleep(p.timeout)
			continue
		}

		// Keep the raw text around: after a failed Atoi the parsed value is
		// always 0, so the raw text is the only useful thing to log.
		raw := string(kvs[0].Value)
		psDesired, err = strconv.Atoi(raw)
		if err != nil {
			log.Errorf("psDesired %s invalid %v", raw, err)
			time.Sleep(p.timeout)
			continue
		}

		// A negative count cannot be used to size the server list, so treat
		// it like any other invalid value and wait for it to be corrected.
		if psDesired < 0 {
			log.Errorf("psDesired %d invalid: must not be negative", psDesired)
			time.Sleep(p.timeout)
			continue
		}

		log.Debugf("Get psDesired number: %d", psDesired)
		break
	}
	return psDesired
}

// List returns the pserver list read from etcd.
//
// It first obtains the desired pserver count via Desired, then reads the key
// pserver.PsPath+<index> for every index in [0, count). A read that fails,
// finds no key, or finds an empty address is retried for the same index
// after sleeping for p.timeout, so List blocks until every pserver has a
// non-empty address registered. Element i of the result has Index i and the
// address read for that index.
func (p *EtcdClient) List() []Server {
	psDesired := p.Desired()

	servers := make([]Server, psDesired)
	// i advances only after a successful read; every `continue` below retries
	// the same index instead of leaving a zero-valued entry in the result.
	for i := 0; i < psDesired; {
		psKey := pserver.PsPath + strconv.Itoa(i)
		log.Debugf("checking %s", psKey)

		ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
		resp, err := p.client.Get(ctx, psKey)
		// Release the context only once the request has completed; cancelling
		// it before the call makes every Get fail immediately.
		cancel()
		if err != nil {
			log.Infof("Get psKey= %s error, %v", psKey, err)
			time.Sleep(p.timeout)
			continue
		}

		kvs := resp.Kvs
		if len(kvs) == 0 {
			log.Infof("Waiting for ps addr registered ...")
			time.Sleep(p.timeout)
			continue
		}

		psAddr := string(kvs[0].Value)
		// TODO(Longfei) check the ps address
		if psAddr == "" {
			log.Infof("Get psKey = %s, psAddr is empty", psKey)
			time.Sleep(p.timeout)
			continue
		}

		log.Infof("got value (%s) for key: %s", psAddr, psKey)
		servers[i].Index = i
		servers[i].Addr = psAddr
		i++
	}
	return servers
}

// NewEtcd creates an etcd client used to read the state of the pservers
// from etcd.
//
// endpoints is a comma-separated list of etcd endpoints. Client creation is
// retried every DefaultEtcdTimeout until it succeeds, so NewEtcd blocks for
// as long as clientv3.New keeps returning an error.
func NewEtcd(endpoints string) *EtcdClient {
	addrs := strings.Split(endpoints, ",")
	cfg := clientv3.Config{
		Endpoints:   addrs,
		DialTimeout: DefaultEtcdTimeout,
	}

	conn, err := clientv3.New(cfg)
	for err != nil {
		log.Errorf("Init etcd connection failed: %v", err)
		time.Sleep(DefaultEtcdTimeout)
		conn, err = clientv3.New(cfg)
	}

	log.Infof("Connected to etcd: %s\n", endpoints)
	return &EtcdClient{
		client:    conn,
		timeout:   DefaultEtcdTimeout,
		endpoints: addrs,
	}
}