提交 9af8d86b 编写于 作者: Y Yancey 提交者: GitHub

Trainer library discover master by etcd (#2551)

* add trainer library

* modifty file name

* move trainer to master client

* update

* update

* modify monitor master to receive a chan

* update

* use etcd client from etcd_client.go

* update

* update

* remove etcd client without lock

* update

* update the comment

* update commonts
上级 9b1c04f7
...@@ -13,10 +13,13 @@ typedef int paddle_master_client; ...@@ -13,10 +13,13 @@ typedef int paddle_master_client;
import "C" import "C"
import ( import (
"strings"
"sync" "sync"
"time"
"unsafe" "unsafe"
"github.com/PaddlePaddle/Paddle/go/master" "github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
...@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client { ...@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client {
return h return h
} }
type addresser string //export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
func (a addresser) Address() string { p := C.GoString(etcdEndpoints)
return string(a) cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
} }
//export paddle_new_master_client //export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client { func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr) a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize) ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
return add(c) return add(c)
} }
......
...@@ -2,18 +2,12 @@ package master ...@@ -2,18 +2,12 @@ package master
import ( import (
"os" "os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection" "github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio" "github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// Addresser provide the address of the master server.
type Addresser interface {
Address() string
}
// Client is the client of the master server. // Client is the client of the master server.
type Client struct { type Client struct {
conn *connection.Conn conn *connection.Conn
...@@ -24,11 +18,11 @@ type Client struct { ...@@ -24,11 +18,11 @@ type Client struct {
// //
// bufSize is the record buffer size. NextRecord will read from this // bufSize is the record buffer size. NextRecord will read from this
// buffer. // buffer.
func NewClient(addr Addresser, bufSize int) *Client { func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
c.ch = make(chan []byte, bufSize) c.ch = make(chan []byte, bufSize)
go c.monitorMaster(addr) go c.monitorMaster(addrCh)
go c.getRecords() go c.getRecords()
return c return c
} }
...@@ -72,12 +66,10 @@ func (c *Client) getRecords() { ...@@ -72,12 +66,10 @@ func (c *Client) getRecords() {
} }
} }
func (c *Client) monitorMaster(addr Addresser) { func (c *Client) monitorMaster(addrCh <-chan string) {
lastMaster := "" lastMaster := ""
monitor := func() { for curMaster := range addrCh {
// get the lastest address of the master server,
// connect to the new address once address changed. // connect to the new address once address changed.
curMaster := addr.Address()
if curMaster != lastMaster { if curMaster != lastMaster {
if curMaster == "" { if curMaster == "" {
err := c.conn.Close() err := c.conn.Close()
...@@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) { ...@@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) {
// to retry next time. // to retry next time.
curMaster = lastMaster curMaster = lastMaster
} }
} }
} }
lastMaster = curMaster lastMaster = curMaster
} }
monitor()
ticker := time.NewTicker(10 * time.Second)
for _ = range ticker.C {
monitor()
}
} }
// SetDataset set dataset for the master server to dispatch. // SetDataset set dataset for the master server to dispatch.
......
...@@ -26,12 +26,6 @@ func init() { ...@@ -26,12 +26,6 @@ func init() {
log.SetLevel(log.ErrorLevel) log.SetLevel(log.ErrorLevel)
} }
type TestAddresser string
func (a TestAddresser) Address() string {
return string(a)
}
func TestGetFinishTask(t *testing.T) { func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0" const path = "/tmp/master_client_test_0"
...@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) { ...@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil { if err != nil {
...@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) { ...@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) {
// Manually intialize client to avoid calling c.getRecords() // Manually intialize client to avoid calling c.getRecords()
c := &Client{} c := &Client{}
c.conn = connection.New() c.conn = connection.New()
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p))) addr := fmt.Sprintf(":%d", p)
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
c.SetDataset([]string{path}) c.SetDataset([]string{path})
checkOnePass := func(i int) { checkOnePass := func(i int) {
var tasks []Task var tasks []Task
for idx := 0; idx < totalTask; idx++ { for idx := 0; idx < totalTask; idx++ {
......
...@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) { ...@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) {
path = "/tmp/master_client_TestFull" path = "/tmp/master_client_TestFull"
total = 50 total = 50
) )
l, err := net.Listen("tcp", ":0") l, err := net.Listen("tcp", ":0")
if err != nil { if err != nil {
panic(err) panic(err)
...@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) { ...@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) {
if err != nil { if err != nil {
panic(err) panic(err)
} }
go func(l net.Listener) { go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
if err != nil { if err != nil {
...@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) { ...@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) {
} }
w.Close() w.Close()
f.Close() f.Close()
curAddr := make(chan string, 1)
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10) curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path}) c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ { for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool) received := make(map[byte]bool)
for i := 0; i < total; i++ { for i := 0; i < total; i++ {
......
...@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) { ...@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) {
state := kvs[0].Value state := kvs[0].Value
return state, nil return state, nil
} }
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
valChan <- string(ev.Kv.Value)
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册