etcd_client.go 6.9 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17 18 19 20 21 22 23 24 25 26
package pserver

import (
	"context"
	"errors"
	"strconv"
	"strings"
	"time"

	"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
	"github.com/coreos/etcd/clientv3"
	"github.com/coreos/etcd/clientv3/concurrency"
27
	log "github.com/inconshreveable/log15"
28 29
)

Q
Qiao Longfei 已提交
30 31 32
// Well-known etcd keys and key prefixes used for pserver coordination.
const (
	// PsDesired is the etcd key that holds the desired number of pservers
	// for the job.
	PsDesired = "/ps_desired"

	// PsPath is the key prefix under which each pserver publishes its address.
	PsPath = "/ps/"

	// PsCheckpoint is the etcd key prefix for checkpoint information.
	PsCheckpoint = "/checkpoints/"
)

// retryTimeout is how long the client sleeps before retrying a failed
// etcd operation.
const retryTimeout = 5 * time.Second

41 42 43
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
44 45 46 47 48 49
	numPservers int
	endpoints   string
	client      *clientv3.Client
	sess        *concurrency.Session
	dialTimeout time.Duration
	ttlSec      int
50 51 52 53 54 55 56 57
	// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
	externalIP string
	// desired number of pservers in the job.
	// assume desired will not change during one training job.
	desired int
}

// NewEtcdClient creates an EtcdClient
58
func NewEtcdClient(endpoints string, numPservers int, dialtimeout time.Duration, ttlSec int) *EtcdClient {
59
	return &EtcdClient{
60 61 62 63
		dialTimeout: dialtimeout,
		ttlSec:      ttlSec,
		numPservers: numPservers,
		endpoints:   endpoints,
64 65 66 67 68 69
	}
}

// Register registers the pserver on etcd
//
// Register returns the index of the current pserver.
70
func (e *EtcdClient) Register(port int) (int, error) {
71 72 73 74 75 76 77
	var err error
	e.externalIP, err = networkhelper.GetExternalIP()
	if err != nil {
		return 0, err
	}

	// initialize connection to etcd.
78
	ep := strings.Split(e.endpoints, ",")
79 80 81
	for {
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   ep,
82
			DialTimeout: e.dialTimeout,
83 84
		})
		if err != nil {
85
			log.Error("connect to etcd error", log.Ctx{"error": err})
86 87 88 89 90 91
			time.Sleep(retryTimeout)
			continue
		}
		e.client = cli
		sess, err := concurrency.NewSession(cli, concurrency.WithTTL(e.ttlSec))
		if err != nil {
92
			log.Error("create etcd session error", log.Ctx{"error": err})
93
			time.Sleep(retryTimeout)
94 95
			continue
		}
96
		e.sess = sess
97
		log.Debug("connected to etcd", log.Ctx{"endpoint": e.endpoints})
98 99 100 101 102 103
		break
	}
	// init /ps_desired using transaction, for multiple pservers may want to write
	// it at the same time.
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
Q
Qiao Longfei 已提交
104
		_, err := e.initDesiredPservers(ctx, e.numPservers)
105 106
		cancel()
		if err != nil {
107
			log.Warn("pserver init error", log.Ctx{"error": err, "num pservers": e.numPservers})
108
			time.Sleep(retryTimeout)
109 110 111 112 113 114 115 116 117 118
			continue
		}
		break
	}
	// TODO: when implementing extending or reducing pservers, /ps_desired is
	// changed, then we need to watch /ps_desired node for events. For now, just
	// write once when init and read from it.
	// wait and set s.desired init value
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
119
		resp, err := e.client.Get(ctx, PsDesired)
120 121
		cancel()
		if err != nil {
122
			log.Error("get etcd key error", log.Ctx{"key": PsDesired, "error": err})
123
			time.Sleep(retryTimeout)
124 125 126 127 128
			continue
		}
		if len(resp.Kvs) != 0 {
			e.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
			if err != nil {
129 130 131 132
				log.Error(
					"psDesired atoi error",
					log.Ctx{"error": err, "value": string(resp.Kvs[0].Value)},
				)
133
				time.Sleep(retryTimeout)
134 135 136 137 138 139 140 141 142 143 144 145
				// NOTE: wait util ps_desired value change
				continue
			}
			break
		}
	}

	var pserverIdx int
	// try register pserver node on etcd
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
		var err error
146
		pserverIdx, err = e.registerPserverEtcd(ctx, port)
147 148
		cancel()
		if err != nil {
149
			log.Warn("register pserver on etcd error", log.Ctx{"error": err})
150
			time.Sleep(retryTimeout)
151 152 153 154 155 156 157 158
			continue
		}
		break
	}

	return pserverIdx, nil
}

Q
Qiao Longfei 已提交
159
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
160
	return concurrency.NewSTM(e.client, func(c concurrency.STM) error {
161 162
		dsStr := c.Get(PsDesired)
		if dsStr == "" {
163
			c.Put(PsDesired, strconv.Itoa(numPservers), clientv3.WithLease(e.sess.Lease()))
164 165 166 167 168 169
		}
		return nil
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
}

// registerPserverEtcd registers pserver node on etcd using transaction.
170
func (e *EtcdClient) registerPserverEtcd(ctx context.Context, port int) (int, error) {
171
	var idx int
172
	_, err := concurrency.NewSTM(e.client, func(c concurrency.STM) error {
173 174
		registered := false
		for i := 0; i < e.desired; i++ {
Q
Qiao Longfei 已提交
175
			psKey := PsPath + strconv.Itoa(i)
176
			ps := c.Get(psKey)
177 178 179 180
			log.Debug(
				"register pserver got value",
				log.Ctx{"value": ps, "key": psKey},
			)
181 182 183

			if ps == "" {
				// find the first id and write info
184
				pserverAddr := e.externalIP + ":" + strconv.Itoa(port)
185
				c.Put(psKey, pserverAddr, clientv3.WithLease(e.sess.Lease()))
186
				log.Debug("register finished", log.Ctx{"key": psKey, "value": pserverAddr})
187 188 189 190 191
				idx = i
				registered = true
				break
			}
		}
192
		if registered {
193 194
			return nil
		}
195
		return errors.New("not registered, may due to already have enough pservers")
196 197 198 199 200 201 202 203
	}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))

	if err != nil {
		return 0, err
	}

	return idx, nil
}
D
dongzhihong 已提交
204

205 206 207
// GetKey gets the value by the specified key
func (e *EtcdClient) GetKey(key string, timeout time.Duration) ([]byte, error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
208
	resp, err := e.client.Get(ctx, key)
209 210 211 212
	cancel()
	if err != nil {
		return []byte{}, err
	}
H
Helin Wang 已提交
213

214 215 216 217 218 219 220 221
	kvs := resp.Kvs
	if len(kvs) == 0 {
		return []byte{}, nil
	}
	v := kvs[0].Value
	return v, nil
}

D
dongzhihong 已提交
222
// PutKey put into etcd with value by key specified
H
Helin Wang 已提交
223
func (e *EtcdClient) PutKey(key string, value []byte, timeout time.Duration, withLease bool) error {
224
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
H
Helin Wang 已提交
225 226 227 228 229 230
	var err error
	if withLease {
		_, err = e.client.Put(ctx, key, string(value), clientv3.WithLease(e.sess.Lease()))
	} else {
		_, err = e.client.Put(ctx, key, string(value))
	}
D
dongzhihong 已提交
231
	cancel()
232
	return err
D
dongzhihong 已提交
233
}
234 235 236 237 238 239 240 241 242 243 244 245

// Shutdown shuts down the etcd client gracefully.
func (e *EtcdClient) Shutdown() error {
	var err error
	if e.sess != nil {
		err = e.sess.Close()
	}

	if e.client != nil {
		newErr := e.client.Close()
		if newErr != nil {
			if err != nil {
246
				log.Error("shutdown error", log.Ctx{"error": newErr})
247 248 249 250 251 252 253
			} else {
				err = newErr
			}
		}
	}
	return err
}