提交 badcdfe1 编写于 作者: W wuyi05

pserver etcd registration

上级 09f34c4b
......@@ -5,18 +5,34 @@ import (
"net/http"
"net/rpc"
"strconv"
"time"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
)
func main() {
port := flag.Int("port", 0, "port of the pserver")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
logLevel := flag.String("log-level", "info", "log level, one of debug")
flag.Parse()
s := pserver.NewService()
err := rpc.Register(s)
level, err := log.ParseLevel(*logLevel)
if err != nil {
panic(err)
}
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout)
if err != nil {
panic(err)
}
err = rpc.Register(s)
if err != nil {
panic(err)
}
......
......@@ -7,6 +7,7 @@ import (
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
......@@ -30,9 +31,12 @@ func init() {
port[i] = p
go func(l net.Listener) {
s := pserver.NewService()
s, err := pserver.NewService("", time.Second*5)
if err != nil {
panic(err)
}
server := rpc.NewServer()
err := server.Register(s)
err = server.Register(s)
if err != nil {
panic(err)
}
......
package pserver
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
)
// ElementType is the type of elements of a Parameter.
......@@ -47,14 +56,113 @@ type Service struct {
mu sync.Mutex
opt *optimizer
paramMap map[string]Parameter
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
}
// NewService creates a new service.
func NewService() *Service {
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{})
return s
s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = utils.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, "/ps_desired")
cancel()
if err != nil {
log.Errorf("getting /ps_desired error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
for _, ev := range resp.Kvs {
log.Debugf("key: %s, value: %s", ev.Key, ev.Value)
if string(ev.Key) == "/ps_desired" {
s.desired, err = strconv.Atoi(string(ev.Value))
if err != nil {
log.Errorf("value of /ps_desired invalid %v\n", err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
}
}
break
}
s.registerPserverEtcd()
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil
}
// registerPserverEtcd registers this pserver on etcd using a transaction.
//
// It scans the keys "/ps/0" .. "/ps/<desired-1>" and writes s.externalIP
// into the first one that is empty. The key is attached to a lease with a
// TTL of 5 seconds which is kept alive in the background.
//
// It returns the transaction response, or an error when the lease cannot be
// granted or kept alive, or when every slot is already taken.
func (s *Service) registerPserverEtcd() (*clientv3.TxnResponse, error) {
	return concurrency.NewSTMRepeatable(context.TODO(), s.etcdClient, func(c concurrency.STM) error {
		// NOTE(review): the STM library may presumably re-run this closure
		// when the transaction conflicts; each run that finds a free slot
		// grants its own lease. Confirm against the etcd STM documentation.
		for i := 0; i < s.desired; i++ {
			psKey := "/ps/" + strconv.Itoa(i)
			log.Debugf("checking %s", psKey)
			ps := c.Get(psKey)
			log.Debugf("got value (%s) for key: %s", ps, psKey)
			if ps != "" {
				// Slot already taken by another pserver.
				continue
			}

			// Only grant a lease once a free slot is found, so that no
			// lease is created for slots that are already occupied.
			resp, err := s.etcdClient.Grant(context.TODO(), 5)
			if err != nil {
				log.Errorf("grant etcd lease error: %v", err)
				return err
			}
			// Found the first free id: write our info into it.
			c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
			log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
			ch, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
			if kaerr != nil {
				log.Errorf("keepalive etcd node error: %v", kaerr)
				return kaerr
			}
			// Drain keepalive responses until the channel is closed.
			// Ranging over the channel never yields a nil response, unlike
			// a single receive on a closed channel.
			go func(ch <-chan *clientv3.LeaseKeepAliveResponse) {
				for ka := range ch {
					log.Debugf("keepalive: %d\n", ka.TTL)
				}
			}(ch)
			log.Debug("register finished")
			return nil
		}
		// Every slot is occupied: report it instead of pretending that the
		// registration succeeded.
		return fmt.Errorf("pserver not registered: all %d desired slots are taken", s.desired)
	})
}
// InitParam initializes a parameter.
......
......@@ -10,12 +10,15 @@ import (
)
func TestFull(t *testing.T) {
s := pserver.NewService()
s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil {
t.FailNow()
}
......@@ -72,8 +75,11 @@ func TestFull(t *testing.T) {
}
func TestMultipleInit(t *testing.T) {
s := pserver.NewService()
err := s.FinishInitParams(0, nil)
s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
err = s.FinishInitParams(0, nil)
if err != nil {
t.FailNow()
}
......@@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) {
}
// TestUninitialized checks that sending a gradient to a service whose
// parameters have not been initialized fails with pserver.Uninitialized.
func TestUninitialized(t *testing.T) {
	s, err := pserver.NewService("", time.Second*5)
	if err != nil {
		// Without a service there is nothing left to test.
		t.Fatal(err)
	}
	err = s.SendGrad(pserver.Gradient{}, nil)
	// Guard against a nil error so that an unexpected success fails the
	// test instead of panicking on err.Error().
	if err == nil || err.Error() != pserver.Uninitialized {
		t.FailNow()
	}
}
func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService()
s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
ch := make(chan struct{}, 2)
errCh := make(chan error, 2)
var wg sync.WaitGroup
......@@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil {
t.FailNow()
}
......
package utils
import (
"errors"
"net"
)
// GetExternalIP walks the local network interfaces and returns the first
// IPv4 address found on an interface that is up and is not a loopback
// device. An error is returned when the interfaces or their addresses cannot
// be listed, or when no such address exists.
func GetExternalIP() (string, error) {
	interfaces, err := net.Interfaces()
	if err != nil {
		return "", err
	}

	for i := range interfaces {
		nic := &interfaces[i]
		isUp := nic.Flags&net.FlagUp != 0
		isLoopback := nic.Flags&net.FlagLoopback != 0
		if !isUp || isLoopback {
			// Interface is down or is the loopback device.
			continue
		}

		addresses, err := nic.Addrs()
		if err != nil {
			return "", err
		}

		for _, address := range addresses {
			// Only IP networks and plain IP addresses carry an IP.
			var candidate net.IP
			if ipNet, ok := address.(*net.IPNet); ok {
				candidate = ipNet.IP
			} else if ipAddr, ok := address.(*net.IPAddr); ok {
				candidate = ipAddr.IP
			}
			if candidate == nil || candidate.IsLoopback() {
				continue
			}

			// To4 yields nil for anything that is not an IPv4 address.
			if v4 := candidate.To4(); v4 != nil {
				return v4.String(), nil
			}
		}
	}

	return "", errors.New("are you connected to the network?")
}
package utils
import "testing"
// TestGetIP reports a test error when GetExternalIP fails on this host.
func TestGetIP(t *testing.T) {
	if _, err := GetExternalIP(); err != nil {
		t.Errorf("GetExternalIP returns error : %v\n", err)
	}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册