提交 9af8d86b 编写于 作者: Y Yancey 提交者: GitHub

Trainer library discover master by etcd (#2551)

* add trainer library

* modifty file name

* move trainer to master client

* update

* update

* modify monitor master to receive a chan

* update

* use etcd client from etcd_client.go

* update

* update

* remove etcd client without lock

* update

* update the comment

* update commonts
上级 9b1c04f7
......@@ -13,10 +13,13 @@ typedef int paddle_master_client;
import "C"
import (
"strings"
"sync"
"time"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/coreos/etcd/clientv3"
log "github.com/sirupsen/logrus"
)
......@@ -48,16 +51,33 @@ func remove(client C.paddle_master_client) *master.Client {
return h
}
type addresser string
func (a addresser) Address() string {
return string(a)
//export paddle_new_etcd_master_client
func paddle_new_etcd_master_client(etcdEndpoints *C.char, timeout int, bufSize int) C.paddle_master_client {
p := C.GoString(etcdEndpoints)
cli, err := clientv3.New(clientv3.Config{
Endpoints: strings.Split(p, ","),
DialTimeout: time.Second * time.Duration(timeout),
})
if err != nil {
panic(err)
}
ch := make(chan string, 1)
a, err := master.GetKey(cli, master.DefaultAddrPath, timeout)
if err != nil {
panic(err)
}
ch <- a
go master.WatchKey(cli, master.DefaultAddrPath, ch)
c := master.NewClient(ch, bufSize)
return add(c)
}
//export paddle_new_master_client
func paddle_new_master_client(addr *C.char, bufSize int) C.paddle_master_client {
a := C.GoString(addr)
c := master.NewClient(addresser(a), bufSize)
ch := make(chan string, 1)
ch <- a
c := master.NewClient(ch, bufSize)
return add(c)
}
......
......@@ -2,18 +2,12 @@ package master
import (
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
log "github.com/sirupsen/logrus"
)
// Addresser provide the address of the master server.
type Addresser interface {
Address() string
}
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
......@@ -24,11 +18,11 @@ type Client struct {
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
func NewClient(addrCh <-chan string, bufSize int) *Client {
c := &Client{}
c.conn = connection.New()
c.ch = make(chan []byte, bufSize)
go c.monitorMaster(addr)
go c.monitorMaster(addrCh)
go c.getRecords()
return c
}
......@@ -72,12 +66,10 @@ func (c *Client) getRecords() {
}
}
func (c *Client) monitorMaster(addr Addresser) {
func (c *Client) monitorMaster(addrCh <-chan string) {
lastMaster := ""
monitor := func() {
// get the lastest address of the master server,
for curMaster := range addrCh {
// connect to the new address once address changed.
curMaster := addr.Address()
if curMaster != lastMaster {
if curMaster == "" {
err := c.conn.Close()
......@@ -94,18 +86,10 @@ func (c *Client) monitorMaster(addr Addresser) {
// to retry next time.
curMaster = lastMaster
}
}
}
lastMaster = curMaster
}
monitor()
ticker := time.NewTicker(10 * time.Second)
for _ = range ticker.C {
monitor()
}
}
// SetDataset set dataset for the master server to dispatch.
......
......@@ -26,12 +26,6 @@ func init() {
log.SetLevel(log.ErrorLevel)
}
type TestAddresser string
func (a TestAddresser) Address() string {
return string(a)
}
func TestGetFinishTask(t *testing.T) {
const path = "/tmp/master_client_test_0"
......@@ -45,7 +39,6 @@ func TestGetFinishTask(t *testing.T) {
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
......@@ -82,9 +75,11 @@ func TestGetFinishTask(t *testing.T) {
// Manually intialize client to avoid calling c.getRecords()
c := &Client{}
c.conn = connection.New()
go c.monitorMaster(TestAddresser(fmt.Sprintf(":%d", p)))
addr := fmt.Sprintf(":%d", p)
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
c.SetDataset([]string{path})
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
......
......@@ -20,7 +20,6 @@ func TestNextRecord(t *testing.T) {
path = "/tmp/master_client_TestFull"
total = 50
)
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
......@@ -31,7 +30,6 @@ func TestNextRecord(t *testing.T) {
if err != nil {
panic(err)
}
go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
if err != nil {
......@@ -63,10 +61,10 @@ func TestNextRecord(t *testing.T) {
}
w.Close()
f.Close()
c := master.NewClient(master.TestAddresser(fmt.Sprintf(":%d", p)), 10)
curAddr := make(chan string, 1)
curAddr <- fmt.Sprintf(":%d", p)
c := master.NewClient(curAddr, 10)
c.SetDataset([]string{path})
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
......
......@@ -142,3 +142,31 @@ func (e *EtcdClient) Load() ([]byte, error) {
state := kvs[0].Value
return state, nil
}
// GetKey gets the value by the specify key.
func GetKey(c *clientv3.Client, key string, timeout int) (string, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(timeout))
resp, err := c.Get(ctx, key)
cancel()
if err != nil {
return "", err
}
kvs := resp.Kvs
if len(kvs) == 0 {
return "", nil
}
v := kvs[0].Value
return string(v), nil
}
// WatchKey watches the specify key and send to valChan if there is some event.
func WatchKey(c *clientv3.Client, key string, valChan chan<- string) {
rch := c.Watch(context.Background(), key)
for wresp := range rch {
for _, ev := range wresp.Events {
// if received event is DELETE, the value will be an empty string
log.Infof("received event %s, %q : %q\n", ev.Type, ev.Kv.Key, ev.Kv.Value)
valChan <- string(ev.Kv.Value)
}
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册