etcd_client.go 6.6 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
package pserver

import (
	"context"
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/clientv3/concurrency"
	log "github.com/sirupsen/logrus"
)

Q
Qiao Longfei 已提交
30 31 32
const (
	// PsDesired is the etcd key that stores the desired pserver count.
	PsDesired = "/ps_desired"
	// PsPath is the etcd key prefix under which each pserver stores its
	// address; the pserver index is appended to form the full key.
	PsPath = "/ps/"
	// PsCheckpoint is the etcd path used to store checkpoint information.
	PsCheckpoint = "/checkpoints/"

	// retryTimeout is how long to sleep before retrying a failed etcd
	// operation.
	retryTimeout = 5 * time.Second
)

41 42 43
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
44 45 46 47 48 49
	numPservers int
	endpoints   string
	client      *clientv3.Client
	sess        *concurrency.Session
	dialTimeout time.Duration
	ttlSec      int
50 51 52 53 54 55 56 57
	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
	externalIP string
	// desired number of pservers in the job.
	// assume desired will not change during one training job.
	desired int
}

// NewEtcdClient creates an EtcdClient
58
func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
59
	return &EtcdClient{
60 61 62 63
		dialTimeout: dialtimeout,
		ttlSec:      ttlSec,
		numPservers: numPservers,
		endpoints:   endpoints,
64 65 66 67 68 69
	}
}

// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
70
func (e *EtcdClient) Register(port int) (int, error) {
71 72 73 74 75 76 77
	var err error
	e.externalIP, err = networkhelper.GetExternalIP()
	if err != nil {
		return 0, err
	}

	// initialize connection to etcd.
78
	ep := strings.Split(e.endpoints, ",")
79 80 81
	for {
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   ep,
82
			DialTimeout: e.dialTimeout,
83 84 85
		})
		if err != nil {
			log.Errorf("connect to etcd error: %v", err)
86 87 88 89 90 91 92 93
			time.Sleep(retryTimeout)
			continue
		}
		e.client = cli
		sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
		if err != nil {
			log.Errorf("create etcd session error: %v", err)
			time.Sleep(retryTimeout)
94 95
			continue
		}
96 97
		e.sess = sess
		log.Debugf("inited client to %s", e.endpoints)
98 99 100 101 102 103
		break
	}
	// init /ps_desired using transaction, for multiple pservers may want to write
	// it at the same time.
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
Q
Qiao Longfei 已提交
104
		_, err := e.initDesiredPservers(ctx, e.numPservers)
105 106 107
		cancel()
		if err != nil {
			log.Warn(err)
108
			time.Sleep(retryTimeout)
109 110 111 112 113 114 115 116 117 118
			continue
		}
		break
	}
	// TODO: when implementing extending or reducing pservers, /ps_desired is
	// changed, then we need to watch /ps_desired node for events. For now, just
	// write once when init and read from it.
	// wait and set s.desired init value
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
119
		resp, err := e.client.Get(ctx, PsDesired)
120 121 122
		cancel()
		if err != nil {
			log.Errorf("getting %s error: %v", PsDesired, err)
123
			time.Sleep(retryTimeout)
124 125 126 127 128 129
			continue
		}
		if len(resp.Kvs) != 0 {
			e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
			if err != nil {
				log.Errorf("value of %s invalid %v\n", PsDesired, err)
130
				time.Sleep(retryTimeout)
131 132 133 134 135 136 137 138 139 140 141 142
				// NOTE: wait util ps_desired value change
				continue
			}
			break
		}
	}

	var pserverIdx int
	// try register pserver node on etcd
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		var err error
143
		pserverIdx, err = e.registerPserverEtcd(ctx, port)
144 145 146
		cancel()
		if err != nil {
			log.Warn(err)
147
			time.Sleep(retryTimeout)
148 149 150 151 152 153 154 155
			continue
		}
		break
	}

	return pserverIdx, nil
}

Q
Qiao Longfei 已提交
156
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
157
	return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
158 159
		dsStr := c.Get(PsDesired)
		if dsStr == "" {
160
			c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
161 162 163 164 165 166
		}
		return nil
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}

// registerPserverEtcd registers pserver node on etcd using transaction.
167
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
168
	var idx int
169
	_, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
170 171
		registered := false
		for i := 0; i < e.desired; i++ {
Q
Qiao Longfei 已提交
172
			psKey := PsPath + strconv.Itoa(i)
173 174 175 176 177 178
			log.Debugf("checking %s", psKey)
			ps := c.Get(psKey)
			log.Debugf("got value (%s) for key: %s", ps, psKey)

			if ps == "" {
				// find the first id and write info
179
				pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
180
				c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
181
				log.Debugf("set pserver node %s with value %s", psKey, pserverAddr)
182 183 184 185 186 187
				log.Debug("register finished")
				idx = i
				registered = true
				break
			}
		}
188
		if registered {
189 190
			return nil
		}
191
		return errors.New("not registered, may due to already have enough pservers")
192 193 194 195 196 197 198 199
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))

	if err != nil {
		return 0, err
	}

	return idx, nil
}
D
dongzhihong 已提交
200

201 202 203
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
204
	resp, err := e.client.Get(ctx, key)
205 206 207 208 209 210 211 212 213 214 215 216
	cancel()
	if err != nil {
		return []byte{}, err
	}
	kvs := resp.Kvs
	if len(kvs) == 0 {
		return []byte{}, nil
	}
	v := kvs[0].Value
	return v, nil
}

D
dongzhihong 已提交
217
// PutKey put into etcd with value by key specified
218 219
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration) error {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
220
	_, err := e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
D
dongzhihong 已提交
221
	cancel()
222
	return err
D
dongzhihong 已提交
223
}
224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243

// Shutdown shuts down the etcd client gracefully.
//
// The session (if any) is closed first, then the client (if any). When
// both closes fail, the session error is returned and the client error
// is only logged; otherwise whichever error occurred is returned.
func (e *EtcdClient) Shutdown() error {
	var sessErr error
	if e.sess != nil {
		sessErr = e.sess.Close()
	}

	if e.client == nil {
		return sessErr
	}

	clientErr := e.client.Close()
	switch {
	case clientErr == nil:
		return sessErr
	case sessErr == nil:
		return clientErr
	default:
		// Only one error can be returned; keep the first and log the other.
		log.Errorln(clientErr)
		return sessErr
	}
}