提交 94bfe2b6 编写于 作者: 武毅 提交者: GitHub

Merge pull request #2544 from typhoonzero/pserver_etcd

Pserver etcd registration
...@@ -5,18 +5,35 @@ import ( ...@@ -5,18 +5,35 @@ import (
"net/http" "net/http"
"net/rpc" "net/rpc"
"strconv" "strconv"
"time"
"github.com/namsral/flag" "github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
log "github.com/sirupsen/logrus"
) )
func main() { func main() {
port := flag.Int("port", 0, "port of the pserver") port := flag.Int("port", 0, "port of the pserver")
etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379",
"comma separated endpoint string for pserver to connect to etcd")
etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls")
logLevel := flag.String("log-level", "info",
"log level, possible values: debug, info, warning, error, fatal, panic")
flag.Parse() flag.Parse()
s := pserver.NewService() level, err := log.ParseLevel(*logLevel)
err := rpc.Register(s) if err != nil {
panic(err)
}
log.SetLevel(level)
timeout := time.Second * time.Duration((*etcdTimeout))
s, err := pserver.NewService(*etcdEndpoint, timeout)
if err != nil {
panic(err)
}
err = rpc.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
...@@ -27,7 +44,9 @@ func main() { ...@@ -27,7 +44,9 @@ func main() {
panic(err) panic(err)
} }
log.Infof("start pserver at port %d", *port)
err = http.Serve(l, nil) err = http.Serve(l, nil)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
...@@ -7,6 +7,7 @@ import ( ...@@ -7,6 +7,7 @@ import (
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
...@@ -30,9 +31,12 @@ func init() { ...@@ -30,9 +31,12 @@ func init() {
port[i] = p port[i] = p
go func(l net.Listener) { go func(l net.Listener) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
panic(err)
}
server := rpc.NewServer() server := rpc.NewServer()
err := server.Register(s) err = server.Register(s)
if err != nil { if err != nil {
panic(err) panic(err)
} }
......
package pserver package pserver
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings"
"sync" "sync"
"time"
"github.com/PaddlePaddle/Paddle/go/utils/networkhelper"
"github.com/coreos/etcd/clientv3"
"github.com/coreos/etcd/clientv3/concurrency"
log "github.com/sirupsen/logrus"
) )
// ElementType is the type of elements of a Parameter. // ElementType is the type of elements of a Parameter.
...@@ -24,6 +33,9 @@ const ( ...@@ -24,6 +33,9 @@ const (
Float64 Float64
) )
// PsDesired is etcd path for store desired pserver count
const PsDesired = "/ps_desired"
// Parameter is a piece of data to sync with the parameter server. // Parameter is a piece of data to sync with the parameter server.
type Parameter struct { type Parameter struct {
Name string Name string
...@@ -47,14 +59,121 @@ type Service struct { ...@@ -47,14 +59,121 @@ type Service struct {
mu sync.Mutex mu sync.Mutex
opt *optimizer opt *optimizer
paramMap map[string]Parameter paramMap map[string]Parameter
etcdEndpoints string
etcdClient *clientv3.Client
// etcdTimeout is also used as retry intervals.
etcdTimeout time.Duration
// desired number of pservers in the job.
// assume desired will not change during one training job.
desired int
// FIXME: ensure GetExternalIP gets the correct ip for trainers to connect.
externalIP string
} }
// NewService creates a new service. // NewService creates a new service, will bypass etcd registration if no
func NewService() *Service { // endpoints specified.
func NewService(endpoints string, timeout time.Duration) (*Service, error) {
s := &Service{opt: newOptimizer(sgd, 0.005)} s := &Service{opt: newOptimizer(sgd, 0.005)}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
return s s.etcdEndpoints = endpoints
s.etcdTimeout = timeout
var err error
s.externalIP, err = networkhelper.GetExternalIP()
if err != nil {
return nil, err
}
if endpoints != "" {
// initialize connection to etcd, try
ep := strings.Split(s.etcdEndpoints, ",")
for {
cli, err := clientv3.New(clientv3.Config{
Endpoints: ep,
DialTimeout: s.etcdTimeout,
})
if err != nil {
log.Errorf("connect to etcd error: %v", err)
time.Sleep(s.etcdTimeout)
continue
}
s.etcdClient = cli
log.Debugf("inited client to %s", s.etcdEndpoints)
break
}
// wait and set s.desired init value
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
resp, err := s.etcdClient.Get(ctx, PsDesired)
cancel()
if err != nil {
log.Errorf("getting %s error: %v", PsDesired, err)
time.Sleep(s.etcdTimeout)
continue
}
if len(resp.Kvs) != 0 {
s.desired, err = strconv.Atoi(string(resp.Kvs[0].Value))
if err != nil {
log.Errorf("value of %s invalid %v\n", PsDesired, err)
time.Sleep(s.etcdTimeout)
// NOTE: wait util ps_desired value change
continue
}
break
}
}
// try register pserver node on etcd
for {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
_, err := s.registerPserverEtcd(ctx)
cancel()
if err != nil {
log.Warn(err)
time.Sleep(s.etcdTimeout)
continue
}
break
}
} // if endpoints != ""
// Bypass etcd registration if no endpoints specified
return s, nil
}
// registerPserverEtcd registers pserver node on etcd using transaction.
func (s *Service) registerPserverEtcd(ctx context.Context) (*clientv3.TxnResponse, error) {
return concurrency.NewSTM(s.etcdClient, func(c concurrency.STM) error {
registered := false
for i := 0; i < s.desired; i++ {
psKey := "/ps/" + strconv.Itoa(i)
log.Debugf("checking %s", psKey)
ps := c.Get(psKey)
log.Debugf("got value (%s) for key: %s", ps, psKey)
if ps == "" {
resp, err := s.etcdClient.Grant(context.TODO(), 5)
if err != nil {
log.Fatal(err)
}
// find the first id and write info
c.Put(psKey, s.externalIP, clientv3.WithLease(resp.ID))
log.Debugf("set pserver node %s with value %s", psKey, s.externalIP)
_, kaerr := s.etcdClient.KeepAlive(context.TODO(), resp.ID)
if kaerr != nil {
log.Errorf("keepalive etcd node error: %v", kaerr)
return kaerr
}
log.Debug("register finished")
registered = true
break
}
}
if registered == true {
return nil
}
return errors.New("not registerd, may due to already have enough pservers")
}, concurrency.WithAbortContext(ctx), concurrency.WithIsolation(concurrency.RepeatableReads))
} }
// InitParam initializes a parameter. // InitParam initializes a parameter.
......
...@@ -10,12 +10,15 @@ import ( ...@@ -10,12 +10,15 @@ import (
) )
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
var p pserver.Parameter var p pserver.Parameter
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -72,8 +75,11 @@ func TestFull(t *testing.T) { ...@@ -72,8 +75,11 @@ func TestFull(t *testing.T) {
} }
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
err := s.FinishInitParams(0, nil) if err != nil {
t.Error(err)
}
err = s.FinishInitParams(0, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) { ...@@ -85,15 +91,18 @@ func TestMultipleInit(t *testing.T) {
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
err := s.SendGrad(pserver.Gradient{}, nil) err = s.SendGrad(pserver.Gradient{}, nil)
if err.Error() != pserver.Uninitialized { if err.Error() != pserver.Uninitialized {
t.FailNow() t.FailNow()
} }
} }
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService() s, err := pserver.NewService("", time.Second*5)
if err != nil {
t.Error(err)
}
ch := make(chan struct{}, 2) ch := make(chan struct{}, 2)
errCh := make(chan error, 2) errCh := make(chan error, 2)
var wg sync.WaitGroup var wg sync.WaitGroup
...@@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -133,7 +142,7 @@ func TestBlockUntilInitialized(t *testing.T) {
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil) err = s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
......
package networkhelper
import (
"errors"
"net"
)
// GetExternalIP returns the ip address of local network interface, not the
// loopback device.
func GetExternalIP() (string, error) {
ifaces, err := net.Interfaces()
if err != nil {
return "", err
}
for _, iface := range ifaces {
if iface.Flags&net.FlagUp == 0 {
continue // interface down
}
if iface.Flags&net.FlagLoopback != 0 {
continue // loopback interface
}
addrs, err := iface.Addrs()
if err != nil {
return "", err
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil || ip.IsLoopback() {
continue
}
ip = ip.To4()
if ip == nil {
continue // not an ipv4 address
}
return ip.String(), nil
}
}
return "", errors.New("are you connected to the network?")
}
package networkhelper
import "testing"
func TestGetIP(t *testing.T) {
_, err := GetExternalIP()
if err != nil {
t.Errorf("GetExternalIP returns error : %v\n", err)
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册