提交 0f3a3e98 编写于 作者: H helinwang 提交者: GitHub

Merge pull request #3321 from helinwang/trainer_etcd

Implement trainer init parameters election with etcd
hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582 hash: 1b9b07408ca7fac27a374dc2ccd2433e4bff090484008a037df967284949a582
updated: 2017-08-03T21:46:51.744995189Z updated: 2017-08-07T23:37:48.867469328Z
imports: imports:
- name: github.com/beorn7/perks - name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
...@@ -10,7 +10,7 @@ imports: ...@@ -10,7 +10,7 @@ imports:
- name: github.com/cockroachdb/cmux - name: github.com/cockroachdb/cmux
version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92 version: 112f0506e7743d64a6eb8fedbcff13d9979bbf92
- name: github.com/coreos/etcd - name: github.com/coreos/etcd
version: c31bec0f29facff13f7c3e3d948e55dd6689ed42 version: d0d1a87aa96ae14914751d42264262cb69eda170
subpackages: subpackages:
- alarm - alarm
- auth - auth
...@@ -24,6 +24,7 @@ imports: ...@@ -24,6 +24,7 @@ imports:
- error - error
- etcdserver - etcdserver
- etcdserver/api - etcdserver/api
- etcdserver/api/etcdhttp
- etcdserver/api/v2http - etcdserver/api/v2http
- etcdserver/api/v2http/httptypes - etcdserver/api/v2http/httptypes
- etcdserver/api/v3client - etcdserver/api/v3client
...@@ -210,11 +211,6 @@ testImports: ...@@ -210,11 +211,6 @@ testImports:
version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9 version: 04cdfd42973bb9c8589fd6a731800cf222fde1a9
subpackages: subpackages:
- spew - spew
- name: github.com/docker/docker
version: b6d164e6c46d8115b146e4c3ac93784e9ef8b49e
subpackages:
- pkg/ioutils
- pkg/longpath
- name: github.com/pmezard/go-difflib - name: github.com/pmezard/go-difflib
version: d8ed2627bdf02c080bf22230dbb337003b7aba2d version: d8ed2627bdf02c080bf22230dbb337003b7aba2d
subpackages: subpackages:
......
package master_test package master_test
import ( import (
"io/ioutil"
"net/url"
"os" "os"
"strings"
"testing" "testing"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/embed" "github.com/coreos/etcd/embed"
"github.com/docker/docker/pkg/ioutils"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestNewServiceWithEtcd(t *testing.T) { func TestNewServiceWithEtcd(t *testing.T) {
// setup an embed etcd server // setup an embed etcd server
etcdDir, err := ioutils.TempDir("", "") etcdDir, err := ioutil.TempDir("", "")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
cfg := embed.NewConfig() cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg) e, err := embed.StartEtcd(cfg)
if err != nil { if err != nil {
...@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) { ...@@ -30,15 +36,13 @@ func TestNewServiceWithEtcd(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
}() }()
select {
case <-e.Server.ReadyNotify():
t.Log("Server is ready!")
case <-time.After(60 * time.Second):
e.Server.Stop() // trigger a shutdown
t.Fatal("Server took too long to start!")
}
ep := []string{"127.0.0.1:2379"} <-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
ep := []string{endpoint}
masterAddr := "127.0.0.1:3306" masterAddr := "127.0.0.1:3306"
store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30) store, err := master.NewEtcdClient(ep, masterAddr, master.DefaultLockPath, master.DefaultAddrPath, master.DefaultStatePath, 30)
if err != nil { if err != nil {
......
...@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -90,8 +90,12 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli ...@@ -114,11 +118,10 @@ func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_cli
} }
//export paddle_new_etcd_pserver_client //export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcdEndpoints *C.char, selected int) C.paddle_pserver_client { func paddle_new_etcd_pserver_client(etcdEndpoints *C.char) C.paddle_pserver_client {
// TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters)
addr := C.GoString(etcdEndpoints) addr := C.GoString(etcdEndpoints)
etcdClient := client.NewEtcd(addr) etcdClient := client.NewEtcd(addr)
c := client.NewClient(etcdClient, etcdClient.Desired(), selector(selected != 0)) c := client.NewClient(etcdClient, etcdClient.Desired(), etcdClient)
return add(c) return add(c)
} }
...@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) { ...@@ -136,7 +139,12 @@ func paddle_pserver_client_release(client C.paddle_pserver_client) {
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.paddle_pserver_client) C.int { func paddle_begin_init_params(client C.paddle_pserver_client) C.int {
c := get(client) c := get(client)
if selected := c.BeginInitParams(); selected { selected, err := c.BeginInitParams()
if err != nil {
panic(err)
}
if selected {
return 1 return 1
} }
return 0 return 0
......
...@@ -27,9 +27,13 @@ import ( ...@@ -27,9 +27,13 @@ import (
// TODO(helin): add RPC call retry logic // TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers. // Selector selects if the client should initialize parameters and
// reports the initialization process done.
type Selector interface { type Selector interface {
Select() bool // Select selects if the client should initialize parameter servers.
Select() (bool, error)
// Done indicates the initialization process is done.
Done() error
} }
// Server is the identification of a parameter Server. // Server is the identification of a parameter Server.
...@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) { ...@@ -115,7 +119,7 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
// servers. Other trainers will be blocked until the initialization is // servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from // done, and they need to get the initialized parameters from
// parameter servers using GetParams. // parameter servers using GetParams.
func (c *Client) BeginInitParams() bool { func (c *Client) BeginInitParams() (bool, error) {
return c.sel.Select() return c.sel.Select()
} }
......
...@@ -124,8 +124,12 @@ func initEtcdClient() { ...@@ -124,8 +124,12 @@ func initEtcdClient() {
type selector bool type selector bool
func (s selector) Select() bool { func (s selector) Select() (bool, error) {
return bool(s) return bool(s), nil
}
func (s selector) Done() error {
return nil
} }
type lister []client.Server type lister []client.Server
...@@ -135,7 +139,11 @@ func (l lister) List() []client.Server { ...@@ -135,7 +139,11 @@ func (l lister) List() []client.Server {
} }
func testClient(t *testing.T, c *client.Client) { func testClient(t *testing.T, c *client.Client) {
selected := c.BeginInitParams() selected, err := c.BeginInitParams()
if err != nil {
t.Fatal(err)
}
if !selected { if !selected {
t.Fatal("should be selected.") t.Fatal("should be selected.")
} }
......
...@@ -16,53 +16,60 @@ package client ...@@ -16,53 +16,60 @@ package client
import ( import (
"context" "context"
"errors"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
"github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
const ( const (
defaultEtcdTimeout time.Duration = 5 * time.Second defaultEtcdTimeout time.Duration = 5 * time.Second
initLockPath = "/init_ps/lock"
initDonePath = "/init_ps/done"
initDoneVal = "1"
) )
// EtcdClient is used by pserver client that is a part of trainer process. // Etcd is used by pserver client that is a part of trainer process.
// TODO: // TODO:
// 1. add watcher to watch the change state of pservers) // 1. add watcher to watch the change state of pservers.
// 1. add etcd lock) type Etcd struct {
type EtcdClient struct {
client *clientv3.Client client *clientv3.Client
timeout time.Duration timeout time.Duration
endpoints []string endpoints []string
lock *concurrency.Mutex
} }
// Desired read ps desired number from etcd. // Desired read ps desired number from etcd.
func (p *EtcdClient) Desired() int { func (e *Etcd) Desired() int {
var psDesired int var psDesired int
for { for {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
resp, err := p.client.Get(ctx, pserver.PsDesired) resp, err := e.client.Get(ctx, pserver.PsDesired)
cancel() cancel()
if err != nil { if err != nil {
log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infoln("Waiting for ps desired registered ...") log.Infoln("Waiting for ps desired registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil { if err != nil {
log.Errorf("psDesired %d invalid %v", psDesired, err) log.Errorf("psDesired %d invalid %v", psDesired, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int { ...@@ -73,26 +80,26 @@ func (p *EtcdClient) Desired() int {
} }
// List return the pserver list read from etcd. // List return the pserver list read from etcd.
func (p *EtcdClient) List() []Server { func (e *Etcd) List() []Server {
psDesired := p.Desired() psDesired := e.Desired()
servers := make([]Server, psDesired) servers := make([]Server, psDesired)
for { for {
for i := 0; i < psDesired; i++ { for i := 0; i < psDesired; i++ {
ctx, cancel := context.WithTimeout(context.Background(), p.timeout) ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
psKey := pserver.PsPath + strconv.Itoa(i) psKey := pserver.PsPath + strconv.Itoa(i)
log.Debugf("checking %s", psKey) log.Debugf("checking %s", psKey)
resp, err := p.client.Get(ctx, psKey) resp, err := e.client.Get(ctx, psKey)
cancel() cancel()
if err != nil { if err != nil {
log.Infof("Get psKey= %s error, %v", psKey, err) log.Infof("Get psKey= %s error, %v", psKey, err)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
kvs := resp.Kvs kvs := resp.Kvs
if len(kvs) == 0 { if len(kvs) == 0 {
log.Infof("Waiting for ps addr registered ...") log.Infof("Waiting for ps addr registered ...")
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
...@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server { ...@@ -100,7 +107,7 @@ func (p *EtcdClient) List() []Server {
// TODO(Longfei) check the ps address // TODO(Longfei) check the ps address
if psAddr == "" { if psAddr == "" {
log.Infof("Get psKey = %s, psAddr is empty", psKey) log.Infof("Get psKey = %s, psAddr is empty", psKey)
time.Sleep(p.timeout) time.Sleep(e.timeout)
continue continue
} }
log.Debugf("got value (%s) for key: %s", psAddr, psKey) log.Debugf("got value (%s) for key: %s", psAddr, psKey)
...@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server { ...@@ -113,7 +120,7 @@ func (p *EtcdClient) List() []Server {
} }
// NewEtcd create a etcd client to return the state of pserver on etcd. // NewEtcd create a etcd client to return the state of pserver on etcd.
func NewEtcd(endpoints string) *EtcdClient { func NewEtcd(endpoints string) *Etcd {
ep := strings.Split(endpoints, ",") ep := strings.Split(endpoints, ",")
var cli *clientv3.Client var cli *clientv3.Client
var err error var err error
...@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient { ...@@ -130,10 +137,118 @@ func NewEtcd(endpoints string) *EtcdClient {
break break
} }
log.Infof("Connected to etcd: %s\n", endpoints) log.Infof("Connected to etcd: %s\n", endpoints)
client := &EtcdClient{ client := &Etcd{
client: cli, client: cli,
timeout: defaultEtcdTimeout, timeout: defaultEtcdTimeout,
endpoints: ep, endpoints: ep,
} }
return client return client
} }
// Select indicates if the current trainer is selected to initialize
// the pserver parameters.
func (e *Etcd) Select() (bool, error) {
sess, err := concurrency.NewSession(e.client, concurrency.WithTTL(5))
if err != nil {
return false, err
}
lock := concurrency.NewMutex(sess, initLockPath)
log.Infof("Trying to acquire lock at %s.", initLockPath)
// Do not use timeout context here, since we don't know how
// long does it take for other trainers to initialize the
// parameters.
err = lock.Lock(context.Background())
if err != nil {
return false, err
}
log.Infof("Successfully acquired lock at %s.", initLockPath)
get := clientv3.OpGet(initDonePath)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(lock.IsOwner()).Then(get).Commit()
cancel()
if err != nil {
return false, err
}
if !tresp.Succeeded {
return false, errors.New("no longer the owner of the lock")
}
resp := tresp.Responses[0].GetResponseRange()
if len(resp.Kvs) == 0 {
// Key value not set, select current trainer.
e.lock = lock
log.Infoln("Trainer selected.")
return true, nil
}
if string(resp.Kvs[0].Value) == initDoneVal {
log.Infoln("Initialization is already done.")
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
}
return false, nil
}
return false, fmt.Errorf("key %s have unexpected value: %v", initDonePath, resp.Kvs[0].Value)
}
// Done indicates the parameter initialization process is done.
func (e *Etcd) Done() error {
if e.lock == nil {
return errors.New("lock is nil, Done called unexpectedly")
}
put := clientv3.OpPut(initDonePath, initDoneVal)
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
tresp, err := e.client.Txn(ctx).If(e.lock.IsOwner()).Then(put).Commit()
cancel()
if err != nil {
return err
}
if !tresp.Succeeded {
return errors.New("no longer the owner of the lock")
}
ctx, cancel = context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err != nil {
log.Errorln(err)
} else {
e.lock = nil
}
return nil
}
// Close closes the etcd client.
func (e *Etcd) Close() error {
var err error
if e.lock != nil {
ctx, cancel := context.WithTimeout(context.Background(), e.timeout)
err = e.lock.Unlock(ctx)
cancel()
if err == nil {
e.lock = nil
}
}
cErr := e.client.Close()
if cErr != nil {
if err != nil {
log.Errorln(cErr)
return err
}
return cErr
}
return err
}
package client_test
import (
"io/ioutil"
"net/url"
"os"
"strings"
"sync"
"testing"
"github.com/PaddlePaddle/Paddle/go/pserver/client"
"github.com/coreos/etcd/embed"
)
func TestSelector(t *testing.T) {
etcdDir, err := ioutil.TempDir("", "")
if err != nil {
t.Fatal(err)
}
cfg := embed.NewConfig()
lpurl, _ := url.Parse("http://localhost:0")
lcurl, _ := url.Parse("http://localhost:0")
cfg.LPUrls = []url.URL{*lpurl}
cfg.LCUrls = []url.URL{*lcurl}
cfg.Dir = etcdDir
e, err := embed.StartEtcd(cfg)
if err != nil {
t.Fatal(err)
}
defer func() {
e.Close()
if err := os.RemoveAll(etcdDir); err != nil {
t.Fatal(err)
}
}()
<-e.Server.ReadyNotify()
port := strings.Split(e.Clients[0].Addr().String(), ":")[1]
endpoint := "127.0.0.1:" + port
var mu sync.Mutex
selectedCount := 0
var wg sync.WaitGroup
selectAndDone := func(c *client.Etcd) {
defer wg.Done()
selected, err := c.Select()
if err != nil {
panic(err)
}
if selected {
mu.Lock()
selectedCount++
mu.Unlock()
err = c.Done()
if err != nil {
t.Fatal(err)
}
}
}
c0 := client.NewEtcd(endpoint)
c1 := client.NewEtcd(endpoint)
c2 := client.NewEtcd(endpoint)
c3 := client.NewEtcd(endpoint)
wg.Add(3)
go selectAndDone(c0)
go selectAndDone(c1)
go selectAndDone(c2)
wg.Wait()
// simulate trainer crashed and restarted after the
// initialization process.
wg.Add(1)
go selectAndDone(c3)
wg.Wait()
mu.Lock()
if selectedCount != 1 {
t.Fatal("selected count wrong:", selectedCount)
}
mu.Unlock()
err = c0.Close()
if err != nil {
t.Fatal(err)
}
err = c1.Close()
if err != nil {
t.Fatal(err)
}
err = c2.Close()
if err != nil {
t.Fatal(err)
}
err = c3.Close()
if err != nil {
t.Fatal(err)
}
}
...@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init( ...@@ -50,8 +50,8 @@ void NewRemoteParameterUpdater::init(
// create parameter server client. // create parameter server client.
if (useEtcd_) { if (useEtcd_) {
parameterClient_ = paddle_new_etcd_pserver_client( parameterClient_ =
(char *)pserverSpec_.c_str(), FLAGS_trainer_id == 0); paddle_new_etcd_pserver_client((char *)pserverSpec_.c_str());
} else { } else {
parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(), parameterClient_ = paddle_new_pserver_client((char *)pserverSpec_.c_str(),
FLAGS_trainer_id == 0); FLAGS_trainer_id == 0);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册