diff --git a/CMakeLists.txt b/CMakeLists.txt index 5349f59805ba35bb03d876e4f7279840c8f8641c..5bedbbefa85a730ff2934a12597988a67e73c1a4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,7 @@ include(coveralls) # set code coverage include_directories("${PROJ_ROOT}") include_directories("${PROJ_ROOT}/paddle/cuda/include") include_directories("${CMAKE_CURRENT_BINARY_DIR}/proto") -include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/cclient") +include_directories("${CMAKE_CURRENT_BINARY_DIR}/go/pserver/client/c") include_directories(${Boost_INCLUDE_DIRS}) set(EXTERNAL_LIBS diff --git a/go/CMakeLists.txt b/go/CMakeLists.txt index 014697d1555859e4d74c55604f8d65d7abe4cbbf..f00c70a0589a4f41a23164a95d505d4310d9157b 100644 --- a/go/CMakeLists.txt +++ b/go/CMakeLists.txt @@ -13,7 +13,7 @@ # limitations under the License. # -add_subdirectory(pserver/cclient) +add_subdirectory(pserver/client/c) add_subdirectory(cmd/pserver) add_subdirectory(cmd/master) add_subdirectory(master/c) diff --git a/go/cmd/pserver/pserver.go b/go/cmd/pserver/pserver.go index 8a42d4f8af1713e246f9efaf5dc7ba878c3b271e..31ef450f032f756fb32a0444a7e94a18ec2918a0 100644 --- a/go/cmd/pserver/pserver.go +++ b/go/cmd/pserver/pserver.go @@ -15,6 +15,7 @@ import ( func main() { port := flag.Int("port", 0, "port of the pserver") + index := flag.Int("index", -1, "index of this pserver, should be larger or equal than 0") etcdEndpoint := flag.String("etcd-endpoint", "http://127.0.0.1:2379", "comma separated endpoint string for pserver to connect to etcd") etcdTimeout := flag.Int("etcd-timeout", 5, "timeout for etcd calls") @@ -29,11 +30,16 @@ func main() { } log.SetLevel(level) - timeout := time.Second * time.Duration((*etcdTimeout)) - e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) - idx, err := e.Register() - if err != nil { - panic(err) + var idx int + if *index >= 0 { + idx = *index + } else { + timeout := time.Second * time.Duration((*etcdTimeout)) + e := pserver.NewEtcdClient(*etcdEndpoint, *numPservers, timeout) + idx, err = e.Register() + if err != nil { + panic(err) + } } s, err := pserver.NewService(idx) diff --git a/go/master/etcd_client.go b/go/master/etcd_client.go index e27c014792f31ca27fe1a1636d69acccc4206ea3..04c1394e963d1eb541b80b91407fb55b0d1e1f2a 100644 --- a/go/master/etcd_client.go +++ b/go/master/etcd_client.go @@ -50,7 +50,7 @@ func NewEtcdClient(endpoints []string, addr string, lockPath, addrPath, statePat lock := concurrency.NewMutex(sess, lockPath) // It's fine for the lock to get stuck, in this case we have // multiple master servers running (only configured to have - // one master running, but split-brain problem may cuase + // one master running, but split-brain problem may cause // multiple master servers running), and the cluster management // software will kill one of them. log.Debugf("Trying to acquire lock at %s.", lockPath) @@ -98,7 +98,7 @@ func (e *EtcdClient) Save(state []byte) error { // We lost the master lock and can not acquire // it back, it means some other master is // already started. We don't want cluster - // managment system to kill the master server + // management system to kill the master server // who is holding the lock and running // correctly. So the most feasible solution is // to kill current master server. The current diff --git a/go/pserver/cclient/CMakeLists.txt b/go/pserver/client/c/CMakeLists.txt similarity index 67% rename from go/pserver/cclient/CMakeLists.txt rename to go/pserver/client/c/CMakeLists.txt index 7fe74c62f109b186eb43383b78f30478b9be74c1..a3fcaeef190a178c1eed806f3e03a14ced780eef 100644 --- a/go/pserver/cclient/CMakeLists.txt +++ b/go/pserver/client/c/CMakeLists.txt @@ -1,5 +1,5 @@ cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf) -go_library(paddle_pserver_cclient STATIC) +go_library(paddle_pserver_cclient STATIC DEPS paddle_go_optimizer) if(WITH_TESTING) add_subdirectory(test) endif() diff --git a/go/pserver/cclient/cclient.go b/go/pserver/client/c/cclient.go similarity index 88% rename from go/pserver/cclient/cclient.go rename to go/pserver/client/c/cclient.go index bbaf43d9f1434a278568bc110a709718b9b8c222..7ddaceb7ed33db32e19a191402100a0c0efa241a 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/client/c/cclient.go @@ -30,15 +30,16 @@ import ( "unsafe" "github.com/PaddlePaddle/Paddle/go/pserver" + "github.com/PaddlePaddle/Paddle/go/pserver/client" log "github.com/sirupsen/logrus" ) var nullPtr = unsafe.Pointer(uintptr(0)) var mu sync.Mutex -var handleMap = make(map[C.paddle_pserver_client]*pserver.Client) +var handleMap = make(map[C.paddle_pserver_client]*client.Client) var curHandle C.paddle_pserver_client -func add(c *pserver.Client) C.paddle_pserver_client { +func add(c *client.Client) C.paddle_pserver_client { mu.Lock() defer mu.Unlock() client := curHandle @@ -47,13 +48,13 @@ func add(c *pserver.Client) C.paddle_pserver_client { return client } -func get(client C.paddle_pserver_client) *pserver.Client { +func get(client C.paddle_pserver_client) *client.Client { mu.Lock() defer mu.Unlock() return handleMap[client] } -func remove(client C.paddle_pserver_client) *pserver.Client { +func remove(client C.paddle_pserver_client) *client.Client { mu.Lock() defer mu.Unlock() h := handleMap[client] @@ -80,9 +81,9 @@ func (s selector) Select() bool { return bool(s) } -type lister []pserver.Server +type lister []client.Server -func (l lister) List() []pserver.Server { +func (l lister) List() []client.Server { return l } @@ -90,19 +91,22 @@ func (l lister) List() []pserver.Server { func paddle_new_pserver_client(addrs *C.char, selected int) C.paddle_pserver_client { a := C.GoString(addrs) as := strings.Split(a, ",") - servers := make([]pserver.Server, len(as)) + servers := make([]client.Server, len(as)) for i := range as { servers[i].Index = i servers[i].Addr = as[i] } - c := pserver.NewClient(lister(servers), len(as), selector(selected != 0)) + c := client.NewClient(lister(servers), len(as), selector(selected != 0)) return add(c) } //export paddle_new_etcd_pserver_client -func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.paddle_pserver_client { - // TODO(helin): fault tolerant pserver client using etcd. - panic("not implemented.") +func paddle_new_etcd_pserver_client(etcd_endpoints *C.char, selected int) C.paddle_pserver_client { + // TODO(Longfei: use etcd lock to decide which trainer to initialize the parameters) + addr := C.GoString(etcd_endpoints) + etcd_client := client.NewEtcd(addr) + c := client.NewClient(etcd_client, etcd_client.Desired(), selector(selected != 0)) + return add(c) } //export paddle_pserver_client_release diff --git a/go/pserver/cclient/test/CMakeLists.txt b/go/pserver/client/c/test/CMakeLists.txt similarity index 100% rename from go/pserver/cclient/test/CMakeLists.txt rename to go/pserver/client/c/test/CMakeLists.txt diff --git a/go/pserver/cclient/test/test_cclient.c b/go/pserver/client/c/test/test_cclient.c similarity index 100% rename from go/pserver/cclient/test/test_cclient.c rename to go/pserver/client/c/test/test_cclient.c diff --git a/go/pserver/cclient/test/test_mnist.py b/go/pserver/client/c/test/test_mnist.py similarity index 100% rename from go/pserver/cclient/test/test_mnist.py rename to go/pserver/client/c/test/test_mnist.py diff --git a/go/pserver/cclient/test/test_train.py b/go/pserver/client/c/test/test_train.py similarity index 100% rename from go/pserver/cclient/test/test_train.py rename to go/pserver/client/c/test/test_train.py diff --git a/go/pserver/cclient/test/testdata/optimizer.pb b/go/pserver/client/c/test/testdata/optimizer.pb similarity index 100% rename from go/pserver/cclient/test/testdata/optimizer.pb rename to go/pserver/client/c/test/testdata/optimizer.pb diff --git a/go/pserver/client.go b/go/pserver/client/client.go similarity index 92% rename from go/pserver/client.go rename to go/pserver/client/client.go index 6938b9d5ce6f6d73c05bd6e3154777023965c319..aa8bfe30c26fcc0875ad479ecd562700ccefa5a3 100644 --- a/go/pserver/client.go +++ b/go/pserver/client/client.go @@ -1,4 +1,4 @@ -package pserver +package client import ( "errors" @@ -7,6 +7,7 @@ import ( "time" "github.com/PaddlePaddle/Paddle/go/connection" + "github.com/PaddlePaddle/Paddle/go/pserver" log "github.com/sirupsen/logrus" ) @@ -105,7 +106,7 @@ func (c *Client) BeginInitParams() bool { } // InitParam initializes the parameter on parameter servers. -func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { +func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error { return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil) } @@ -123,13 +124,13 @@ func (c *Client) FinishInitParams() error { // SendGrads sends gradients to parameter servers for updating // parameters. -func (c *Client) SendGrads(grads []Gradient) error { +func (c *Client) SendGrads(grads []pserver.Gradient) error { if len(grads) == 0 { return errors.New("no gradient received") } errCh := make(chan error, len(grads)) for _, g := range grads { - go func(g Gradient) { + go func(g pserver.Gradient) { err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil) errCh <- err }(g) @@ -151,7 +152,7 @@ func (c *Client) SendGrads(grads []Gradient) error { type result struct { idx int - param Parameter + param pserver.Parameter err error } @@ -170,12 +171,12 @@ func (r results) Swap(i int, j int) { } // GetParams gets parameters from parameter servers. -func (c *Client) GetParams(names []string) ([]Parameter, error) { +func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) { rCh := make(chan result, len(names)) for idx, name := range names { go func(name string, idx int) { - var parameter Parameter + var parameter pserver.Parameter err := c.pservers[c.partition(name)].Call("Service.GetParam", name, ¶meter) rCh <- result{idx: idx, param: parameter, err: err} }(name, idx) @@ -196,7 +197,7 @@ func (c *Client) GetParams(names []string) ([]Parameter, error) { } sort.Sort(rs) - ps := make([]Parameter, len(rs)) + ps := make([]pserver.Parameter, len(rs)) for i := range rs { ps[i] = rs[i].param } diff --git a/go/pserver/client_test.go b/go/pserver/client/client_test.go similarity index 54% rename from go/pserver/client_test.go rename to go/pserver/client/client_test.go index b805efa921630098f7ee2fcce8c02722d57d7485..29b400812c9dc3a5f44700eacbf7ba043248f2f2 100644 --- a/go/pserver/client_test.go +++ b/go/pserver/client/client_test.go @@ -1,6 +1,7 @@ -package pserver_test +package client_test import ( + "context" "io/ioutil" "net" "net/http" @@ -8,15 +9,25 @@ import ( "strconv" "strings" "testing" + "time" "github.com/PaddlePaddle/Paddle/go/pserver" + "github.com/PaddlePaddle/Paddle/go/pserver/client" + "github.com/coreos/etcd/clientv3" + log "github.com/sirupsen/logrus" ) -const numPserver = 10 +const ( + numPserver = 10 + etcdEndpoints = "127.0.0.1:2379" + timeout = 2 * time.Second +) -var port [numPserver]int +var pserverClientPorts [numPserver]int -func init() { +// this function init pserver client and return their ports in an array. +func initClient() [numPserver]int { + var ports [numPserver]int for i := 0; i < numPserver; i++ { l, err := net.Listen("tcp", ":0") if err != nil { @@ -28,7 +39,7 @@ func init() { if err != nil { panic(err) } - port[i] = p + ports[i] = p go func(l net.Listener) { s, err := pserver.NewService(0) @@ -49,6 +60,31 @@ func init() { } }(l) } + return ports +} + +func initNativeClient() { + pserverClientPorts = initClient() +} + +func initEtcdClient() { + client, err := clientv3.New(clientv3.Config{ + Endpoints: []string{etcdEndpoints}, + DialTimeout: time.Second * time.Duration(1), + }) + if err != nil { + log.Errorf("err %v", err) + } + ctx, cancel := context.WithTimeout(context.Background(), timeout) + client.Delete(ctx, pserver.PsDesired) + client.Delete(ctx, pserver.PsPath) + client.Put(ctx, pserver.PsDesired, strconv.Itoa(numPserver)) + ports := initClient() + for i := 0; i < numPserver; i++ { + client.Put(ctx, pserver.PsPath+strconv.Itoa(i), ":"+strconv.Itoa(ports[i])) + } + cancel() + client.Close() } type selector bool @@ -57,25 +93,20 @@ func (s selector) Select() bool { return bool(s) } -type lister []pserver.Server +type lister []client.Server -func (l lister) List() []pserver.Server { +func (l lister) List() []client.Server { return l } -func TestClientFull(t *testing.T) { - servers := make([]pserver.Server, numPserver) - for i := 0; i < numPserver; i++ { - servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])} - } - c := pserver.NewClient(lister(servers), len(servers), selector(true)) +func ClientTest(t *testing.T, c *client.Client) { selected := c.BeginInitParams() if !selected { t.Fatal("should be selected.") } const numParameter = 100 - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile("./c/test/testdata/optimizer.pb") if err != nil { t.Fatalf("read optimizer proto failed") } @@ -129,3 +160,21 @@ func TestClientFull(t *testing.T) { } } } + +func TestNativeClient(t *testing.T) { + initNativeClient() + servers := make([]client.Server, numPserver) + for i := 0; i < numPserver; i++ { + servers[i] = client.Server{Index: i, Addr: ":" + strconv.Itoa(pserverClientPorts[i])} + } + c1 := client.NewClient(lister(servers), len(servers), selector(true)) + ClientTest(t, c1) +} + +// TODO: tmperary disable etcdClient test for dependency of etcd) +func EtcdClient(t *testing.T) { + initEtcdClient() + etcd_client := client.NewEtcd(etcdEndpoints) + c2 := client.NewClient(etcd_client, etcd_client.Desired(), selector(true)) + ClientTest(t, c2) +} diff --git a/go/pserver/client/etcd_client.go b/go/pserver/client/etcd_client.go new file mode 100644 index 0000000000000000000000000000000000000000..1fd3479aa88ccbbe7c5067da1e9886b65352e847 --- /dev/null +++ b/go/pserver/client/etcd_client.go @@ -0,0 +1,125 @@ +package client + +import ( + "context" + "strconv" + "strings" + "time" + + "github.com/PaddlePaddle/Paddle/go/pserver" + "github.com/coreos/etcd/clientv3" + log "github.com/sirupsen/logrus" +) + +const ( + DefaultEtcdTimeout time.Duration = 5 * time.Second +) + +// EtcdClient is used by pserver client that is a part of trainer process. +// TODO: +// 1. add watcher to watch the change state of pservers) +// 1. add etcd lock) +type EtcdClient struct { + client *clientv3.Client + timeout time.Duration + endpoints []string +} + +// Desired read ps desired number from etcd. +func (p *EtcdClient) Desired() int { + var psDesired int + for { + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + resp, err := p.client.Get(ctx, pserver.PsDesired) + cancel() + if err != nil { + log.Errorf("Get ps dresire number failed! recnnectiong..., %v", err) + time.Sleep(p.timeout) + continue + } + + kvs := resp.Kvs + if len(kvs) == 0 { + log.Infoln("Waiting for ps desired registered ...") + time.Sleep(p.timeout) + continue + } + + psDesired, err = strconv.Atoi(string(resp.Kvs[0].Value)) + if err != nil { + log.Errorf("psDesired %s invalid %v", psDesired, err) + time.Sleep(p.timeout) + continue + } + + log.Debugf("Get psDesired number: %d", psDesired) + break + } + return psDesired +} + +// List return the pserver list read from etcd. +func (p *EtcdClient) List() []Server { + psDesired := p.Desired() + + servers := make([]Server, psDesired) + for { + for i := 0; i < psDesired; i++ { + ctx, cancel := context.WithTimeout(context.Background(), p.timeout) + cancel() + psKey := pserver.PsPath + strconv.Itoa(i) + log.Debugf("checking %s", psKey) + resp, err := p.client.Get(ctx, psKey) + if err != nil { + log.Infof("Get psKey= %s error, %v", psKey, err) + time.Sleep(p.timeout) + continue + } + kvs := resp.Kvs + if len(kvs) == 0 { + log.Infof("Waiting for ps addr registered ...") + time.Sleep(p.timeout) + continue + } + + psAddr := string(resp.Kvs[0].Value) + // TODO(Longfei) check the ps address + if psAddr == "" { + log.Infof("Get psKey = %s, psAddr is empty", psKey) + time.Sleep(p.timeout) + continue + } + log.Infof("got value (%s) for key: %s", psAddr, psKey) + servers[i].Index = i + servers[i].Addr = psAddr + } + break + } + return servers +} + +// NewEtcd create a etcd client to return the state of pserver on etcd. +func NewEtcd(endpoints string) *EtcdClient { + ep := strings.Split(endpoints, ",") + var cli *clientv3.Client + var err error + for { + cli, err = clientv3.New(clientv3.Config{ + Endpoints: ep, + DialTimeout: DefaultEtcdTimeout, + }) + if err != nil { + log.Errorf("Init etcd connection failed: %v", err) + time.Sleep(DefaultEtcdTimeout) + continue + } + break + } + log.Infof("Connected to etcd: %s\n", endpoints) + client := &EtcdClient{ + client: cli, + timeout: DefaultEtcdTimeout, + endpoints: ep, + } + return client +} diff --git a/go/pserver/etcd_client.go b/go/pserver/etcd_client.go index 4d88243edd4aa817ddc263ba316a3f6be9e1e67f..37b8d522c1bd07acb41b9515a6d9bc15eae9aa32 100644 --- a/go/pserver/etcd_client.go +++ b/go/pserver/etcd_client.go @@ -13,6 +13,13 @@ import ( log "github.com/sirupsen/logrus" ) +const ( + // PsDesired is etcd path for store desired pserver count + PsDesired = "/ps_desired" + // PsAddr is the base dir for pserver to store their addr + PsPath = "/ps/" +) + // EtcdClient is the etcd client that the pserver uses for fault // tolerance, service registry and coordination. type EtcdClient struct { @@ -68,7 +75,7 @@ func (e *EtcdClient) Register() (int, error) { // it at the same time. for { ctx, cancel := context.WithTimeout(context.Background(), time.Second) - _, err := e.initDesiredPsercers(ctx, e.numPservers) + _, err := e.initDesiredPservers(ctx, e.numPservers) cancel() if err != nil { log.Warn(err) @@ -120,7 +127,7 @@ func (e *EtcdClient) Register() (int, error) { return pserverIdx, nil } -func (e *EtcdClient) initDesiredPsercers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { +func (e *EtcdClient) initDesiredPservers(ctx context.Context, numPservers int) (*clientv3.TxnResponse, error) { return concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { dsStr := c.Get(PsDesired) if dsStr == "" { @@ -136,7 +143,7 @@ func (e *EtcdClient) registerPserverEtcd(ctx context.Context) (int, error) { _, err := concurrency.NewSTM(e.etcdClient, func(c concurrency.STM) error { registered := false for i := 0; i < e.desired; i++ { - psKey := "/ps/" + strconv.Itoa(i) + psKey := PsPath + strconv.Itoa(i) log.Debugf("checking %s", psKey) ps := c.Get(psKey) log.Debugf("got value (%s) for key: %s", ps, psKey) diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index b4a040f46bff5c25b193d41e5d36b59762891574..bca3718af32b35416e94606816569dd9e76eccb6 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -2,7 +2,7 @@ package pserver // #cgo CFLAGS: -I ../../ // //FIXME: ldflags contain "build" path -// #cgo LDFLAGS: ../../build/go/pserver/cclient/libpaddle_go_optimizer.a -lstdc++ +// #cgo LDFLAGS: ../../build/go/pserver/client/c/libpaddle_go_optimizer.a -lstdc++ // #include "paddle/optimizer/optimizer.h" // #include // #include diff --git a/go/pserver/optimizer_test.go b/go/pserver/optimizer_test.go index b99b5a5f0bfed4d780ea19b75ddaa4129be77bd5..0b2f4cfa41a630645c128ac13826de9d8b1d521b 100644 --- a/go/pserver/optimizer_test.go +++ b/go/pserver/optimizer_test.go @@ -11,7 +11,7 @@ func TestOptimizerCreateRelease(t *testing.T) { ElementType: Int32, } p.Content = []byte{1, 3} - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile("./client/c/test/testdata/optimizer.pb") if err != nil { t.Fatalf("read optimizer proto failed") } diff --git a/go/pserver/service.go b/go/pserver/service.go index e15a4e5a58a3bb1a154157b1212d141478e96231..7711dc027e173e862f9b33e7a57224097026872c 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -24,9 +24,6 @@ const ( Float64 ) -// PsDesired is etcd path for store desired pserver count -const PsDesired = "/ps_desired" - // Parameter is a piece of data to sync with the parameter server. type Parameter struct { Name string diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index 30e3ac8ae1ccd1d6c9e71ed113cc2543e8c1e224..b6d20d2c8b7ba0ccd7ab46669a597a21dc11c381 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -10,6 +10,10 @@ import ( "github.com/PaddlePaddle/Paddle/go/pserver" ) +const ( + OptimizerConfig = "./client/c/test/testdata/optimizer.pb" +) + func TestServiceFull(t *testing.T) { s, err := pserver.NewService(0) if err != nil { @@ -19,7 +23,7 @@ func TestServiceFull(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile(OptimizerConfig) if err != nil { t.Fatalf("read optimizer proto failed") } @@ -149,7 +153,7 @@ func TestBlockUntilInitialized(t *testing.T) { p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - config, err := ioutil.ReadFile("./cclient/test/testdata/optimizer.pb") + config, err := ioutil.ReadFile(OptimizerConfig) if err != nil { t.Fatalf("read optimizer proto failed") }