提交 9045063b 编写于 作者: Q Qiao Longfei 提交者: GitHub

pserver etcd client (#2559)

* init etcd cclient

* add etcd

* add etcd.go

* fix compile problem

* move code to etcd.go

* add etcd_lister.go for pserver client

* add etcd_client_test.go

* merge etcd_client_test and client_test

* refine client_test.go

* refine code

* format code

* add TODO and use interface instead of struct

* fix typo of initDesiredPservers

* optimize dir structure of go/pserver/client

* add a flag to config index for pserver

* follow comment

* fix path

* optimize code

* remove err in pserver NewEtcd

* restore comment about /ps_desired
上级 b688840f
......@@ -113,7 +113,7 @@ include(coveralls) # set code coverage
include_directories("${PROJ_ROOT}")
include_directories("${PROJ_ROOT}/paddle/cuda/include")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient")
include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c")
include_directories(${Boost_INCLUDE_DIRS})
set(EXTERNAL_LIBS
......
......@@ -13,7 +13,7 @@
# limitations under the License.
#
add_subdirectory(pserver/cclient)
add_subdirectory(pserver/client/c)
add_subdirectory(cmd/pserver)
add_subdirectory(cmd/master)
add_subdirectory(master/c)
......@@ -15,6 +15,7 @@ import (
func main() {
port := flag.Int("port", 0, "port of the pserver")
index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
......@@ -29,12 +30,17 @@ func main() {
}
log.SetLevel(level)
var idx int
if *index >= 0 {
idx = *index
} else {
timeout := time.Second * time.Duration((*etcdTimeout))
e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout)
idx, err := e.Register()
idx, err = e.Register()
if err != nil {
panic(err)
}
}
s, err := pserver.NewService(idx)
if err != nil {
......
......@@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat
lock := concurrency.NewMutex(sess, lockPath)
// It's fine for the lock to get stuck, in this case we have
// multiple master servers running (only configured to have
// one master running, but split-brain problem may cuase
// one master running, but split-brain problem may cause
// multiple master servers running), and the cluster management
// software will kill one of them.
log.Debugf("Trying to acquire lock at %s.", lockPath)
......@@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error {
// We lost the master lock and can not acquire
// it back, it means some other master is
// already started. We don't want cluster
// managment system to kill the master server
// management system to kill the master server
// who is holding the lock and running
// correctly. So the most feasible solution is
// to kill current master server. The current
......
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
go_library(paddle_pserver_cclient STATIC)
go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer)
if(WITH_TESTING)
add_subdirectory(test)
endif()
......@@ -30,15 +30,16 @@ import (
"unsafe"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
log "github.com/sirupsen/logrus"
)
var nullPtr = unsafe.Pointer(uintptr(0))
var mu sync.Mutex
var handleMap = make(map[C.paddle_pserver_client]*pserver.Client)
var handleMap = make(map[C.paddle_pserver_client]*client.Client)
var curHandle C.paddle_pserver_client
func add(c *pserver.Client) C.paddle_pserver_client {
func add(c *client.Client) C.paddle_pserver_client {
mu.Lock()
defer mu.Unlock()
client := curHandle
......@@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client {
return client
}
func get(client C.paddle_pserver_client) *pserver.Client {
func get(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
return handleMap[client]
}
func remove(client C.paddle_pserver_client) *pserver.Client {
func remove(client C.paddle_pserver_client) *client.Client {
mu.Lock()
defer mu.Unlock()
h := handleMap[client]
......@@ -80,9 +81,9 @@ func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
type lister []client.Server
func (l lister) List() []pserver.Server {
func (l lister) List() []client.Server {
return l
}
......@@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server {
func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client {
a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as))
servers := make([]client.Server, len(as))
for i := range as {
servers[i].Index = i
servers[i].Addr = as[i]
}
c := pserver.NewClient(lister(servers), len(as), selector(selected != 0))
c := client.NewClient(lister(servers), len(as), selector(selected != 0))
return add(c)
}
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client {
// TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.")
func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
addr := C.GoString(etcd_endpoints)
etcd_client := client.NewEtcd(addr)
c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0))
return add(c)
}
//export paddle_pserver_client_release
......
package pserver
package client
import (
"errors"
......@@ -7,6 +7,7 @@ import (
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
)
......@@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool {
}
// InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error {
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
}
......@@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error {
// SendGrads sends gradients to parameter servers for updating
// parameters.
func (c *Client) SendGrads(grads []Gradient) error {
func (c *Client) SendGrads(grads []pserver.Gradient) error {
if len(grads) == 0 {
return errors.New("no gradient received")
}
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {
go func(g pserver.Gradient) {
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
errCh <- err
}(g)
......@@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
type result struct {
idx int
param Parameter
param pserver.Parameter
err error
}
......@@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) {
}
// GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) {
func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
rCh := make(chan result, len(names))
for idx, name := range names {
go func(name string, idx int) {
var parameter Parameter
var parameter pserver.Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, param: parameter, err: err}
}(name, idx)
......@@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) {
}
sort.Sort(rs)
ps := make([]Parameter, len(rs))
ps := make([]pserver.Parameter, len(rs))
for i := range rs {
ps[i] = rs[i].param
}
......
package pserver_test
package client_test
import (
"context"
"io/ioutil"
"net"
"net/http"
......@@ -8,15 +9,25 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
const numPserver = 10
const (
numPserver = 10
etcdEndpoints = "127.0.0.1:2379"
timeout = 2 * time.Second
)
var port [numPserver]int
var pserverClientPorts [numPserver]int
func init() {
// this function init pserver client and return their ports in an array.
func initClient() [numPserver]int {
var ports [numPserver]int
for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0")
if err != nil {
......@@ -28,7 +39,7 @@ func init() {
if err != nil {
panic(err)
}
port[i] = p
ports[i] = p
go func(l net.Listener) {
s, err := pserver.NewService(0)
......@@ -49,6 +60,31 @@ func init() {
}
}(l)
}
return ports
}
func initNativeClient() {
pserverClientPorts = initClient()
}
func initEtcdClient() {
client, err := clientv3.New(clientv3.Config{
Endpoints: []string{etcdEndpoints},
DialTimeout: time.Second * time.Duration(1),
})
if err != nil {
log.Errorf("err %v", err)
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
client.Delete(ctx, pserver.PsDesired)
client.Delete(ctx, pserver.PsPath)
client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver))
ports := initClient()
for i := 0; i < numPserver; i++ {
client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i]))
}
cancel()
client.Close()
}
type selector bool
......@@ -57,25 +93,20 @@ func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
type lister []client.Server
func (l lister) List() []pserver.Server {
func (l lister) List() []client.Server {
return l
}
func TestClientFull(t *testing.T) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
func ClientTest(t *testing.T, c *client.Client) {
selected := c.BeginInitParams()
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 100
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......@@ -129,3 +160,21 @@ func TestClientFull(t *testing.T) {
}
}
}
func TestNativeClient(t *testing.T) {
initNativeClient()
servers := make([]client.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])}
}
c1 := client.NewClient(lister(servers), len(servers), selector(true))
ClientTest(t, c1)
}
// TODO: tmperary disable etcdClient test for dependency of etcd)
func EtcdClient(t *testing.T) {
initEtcdClient()
etcd_client := client.NewEtcd(etcdEndpoints)
c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true))
ClientTest(t, c2)
}
package client
import (
"context"
"strconv"
"strings"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
const (
DefaultEtcdTimeout time.Duration = 5 * time.Second
)
// EtcdClient is used by pserver client that is a part of trainer process.
// TODO:
// 1. add watcher to watch the change state of pservers)
// 1. add etcd lock)
type EtcdClient struct {
client *clientv3.Client
timeout time.Duration
endpoints []string
}
// Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int {
var psDesired int
for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired)
cancel()
if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout)
continue
}
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("psDesired %s invalid %v", psDesired, err)
time.Sleep(p.timeout)
continue
}
log.Debugf("Get psDesired number: %d", psDesired)
break
}
return psDesired
}
// List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server {
psDesired := p.Desired()
servers := make([]Server, psDesired)
for {
for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout)
cancel()
psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey)
if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout)
continue
}
kvs := resp.Kvs
if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout)
continue
}
psAddr := string(resp.Kvs[0].Value)
// TODO(Longfei) check the ps address
if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout)
continue
}
log.Infof("got value (%s) for key: %s", psAddr, psKey)
servers[i].Index = i
servers[i].Addr = psAddr
}
break
}
return servers
}
// NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient {
ep := strings.Split(endpoints, ",")
var cli *clientv3.Client
var err error
for {
cli, err = clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: DefaultEtcdTimeout,
})
if err != nil {
log.Errorf("Init etcd connection failed: %v", err)
time.Sleep(DefaultEtcdTimeout)
continue
}
break
}
log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{
client: cli,
timeout: DefaultEtcdTimeout,
endpoints: ep,
}
return client
}
......@@ -13,6 +13,13 @@ import (
log "github.com/sirupsen/logrus"
)
const (
// PsDesired is etcd path for store desired pserver count
PsDesired = "/ps_desired"
// PsAddr is the base dir for pserver to store their addr
PsPath = "/ps/"
)
// EtcdClient is the etcd client that the pserver uses for fault
// tolerance, service registry and coordination.
type EtcdClient struct {
......@@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) {
// it at the same time.
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := e.initDesiredPsercers(ctx, e.numPservers)
_, err := e.initDesiredPservers(ctx, e.numPservers)
cancel()
if err != nil {
log.Warn(err)
......@@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) {
return pserverIdx, nil
}
func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
dsStr := c.Get(PsDesired)
if dsStr == "" {
......@@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) {
_, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < e.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
psKey := PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
......
......@@ -2,7 +2,7 @@ package pserver
// #cgo CFLAGS: -I ../../
// //FIXME: ldflags contain "build" path
// #cgo LDFLAGS: ../../build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++
// #cgo LDFLAGS: ../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++
// #include "paddle/optimizer/optimizer.h"
// #include <stdlib.h>
// #include <string.h>
......
......@@ -11,7 +11,7 @@ func TestOptimizerCreateRelease(t *testing.T) {
ElementType: Int32,
}
p.Content = []byte{1, 3}
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb")
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......
......@@ -24,9 +24,6 @@ const (
Float64
)
// PsDesired is etcd path for store desired pserver count
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server.
type Parameter struct {
Name string
......
......@@ -10,6 +10,10 @@ import (
"github.com/PaddlePaddle/Paddle/go/pserver"
)
const (
OptimizerConfig = "./client/c/test/testdata/optimizer.pb"
)
func TestServiceFull(t *testing.T) {
s, err := pserver.NewService(0)
if err != nil {
......@@ -19,7 +23,7 @@ func TestServiceFull(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......@@ -149,7 +153,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb")
config, err := ioutil.ReadFile(OptimizerConfig)
if err != nil {
t.Fatalf("read optimizer proto failed")
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册