diff --git a/doc/design/cluster_train/pserver_client.md b/doc/design/cluster_train/pserver_client.md index 007285640e9f11c55715291774826620419cec66..b3e4079010490b69db1de28157f0cab80cad2381 100644 --- a/doc/design/cluster_train/pserver_client.md +++ b/doc/design/cluster_train/pserver_client.md @@ -136,6 +136,9 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad /** * @brief paddle_get_params gets parameters from parameter servers. * + * paddle_get_params will block until parameters are initialized on + * the parameter servers. + * * @param names the array of names of the parameters to get. * @param dst the destination array of parameters to save to. * @param len the length of the names array and the paddle_parameter diff --git a/go/pserver/cclient/cclient.go b/go/pserver/cclient/cclient.go index 1b5560451a89056dd71939a072d809bf7f00fc3f..0b4aa79806b72f4608230d2216d1741389913d95 100644 --- a/go/pserver/cclient/cclient.go +++ b/go/pserver/cclient/cclient.go @@ -39,6 +39,7 @@ import "C" import ( "log" + "strings" "sync" "unsafe" @@ -86,29 +87,46 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { return (*[1 << 30]byte)(p)[:len:len] } +type selector bool + +func (s selector) Select() bool { + return bool(s) +} + +type lister []pserver.Server + +func (l lister) List() []pserver.Server { + return l +} + //export paddle_new_pserver_client -func paddle_new_pserver_client(addr *C.char) C.client { - c := pserver.NewClient(C.GoString(addr)) +func paddle_new_pserver_client(addrs *C.char, selected int) C.client { + a := C.GoString(addrs) + as := strings.Split(a, ",") + servers := make([]pserver.Server, len(as)) + for i := range as { + servers[i].Index = i + servers[i].Addr = as[i] + } + c := pserver.NewClient(lister(servers), len(as), selector(selected != 0)) return add(c) } +//export paddle_new_etcd_pserver_client +func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.client { + // TODO(helin): fault tolerant pserver client using etcd. + panic("not implemented.") +} + //export paddle_pserver_client_release func paddle_pserver_client_release(client C.client) { - c := remove(client) - c.Cleanup() + remove(client) } //export paddle_begin_init_params -func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int { +func paddle_begin_init_params(client C.client) C.int { c := get(client) - b := cArrayToSlice(pserver_config, int(config_len)) - selected, err := c.BeginInitParams(b) - if err != nil { - log.Println(err) - return -1 - } - - if selected { + if selected := c.BeginInitParams(); selected { return 1 } return 0 @@ -230,7 +248,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter func paddle_save_model(client C.client, path *C.char) C.int { p := C.GoString(path) c := get(client) - err := c.SaveModel(p) + err := c.Save(p) if err != nil { log.Println(err) return -1 diff --git a/go/pserver/cclient/test/main.c b/go/pserver/cclient/test/main.c index abfb32e5603f5b51036f1f8475c8e3aca2f05ccb..c14037235c144e1193e0ed2a4c1b01787b92b202 100644 --- a/go/pserver/cclient/test/main.c +++ b/go/pserver/cclient/test/main.c @@ -11,9 +11,9 @@ void fail() { int main() { char addr[] = "localhost:3000"; - client c = paddle_new_pserver_client(addr); + client c = paddle_new_pserver_client(addr, 1); retry: - if (paddle_begin_init_params(c, NULL, 0)) { + if (paddle_begin_init_params(c)) { paddle_parameter param; char name_a[] = "param_a"; char name_b[] = "param_b"; diff --git a/go/pserver/client.go b/go/pserver/client.go index 1c98aea6d1c429a7b51510ddee76ff2700d4a688..f8bd0aa59f30ec7e2b2d318929af96135d3128ed 100644 --- a/go/pserver/client.go +++ b/go/pserver/client.go @@ -1,12 +1,82 @@ package pserver +import ( + "hash/fnv" + "log" + "sort" + "time" + + "github.com/PaddlePaddle/Paddle/go/pserver/internal/connection" +) + +// TODO(helin): add RPC call retry logic + +// Selector selects if the client should initialize parameter servers. +type Selector interface { + Select() bool +} + +// Server is the identification of a parameter Server. +type Server struct { + Index int + Addr string +} + +// Lister lists currently available parameter servers. +type Lister interface { + List() []Server +} + // Client is the client to parameter servers. type Client struct { + sel Selector + pservers []*connection.Conn } // NewClient creates a new client. -func NewClient(addr string) *Client { - return &Client{} +func NewClient(l Lister, pserverNum int, sel Selector) *Client { + c := &Client{sel: sel} + c.pservers = make([]*connection.Conn, pserverNum) + for i := 0; i < pserverNum; i++ { + c.pservers[i] = connection.New() + } + go c.monitorPservers(l, pserverNum) + return c +} + +// monitorPservers monitors pserver addresses, and updates connection +// when the address changes. +func (c *Client) monitorPservers(l Lister, pserverNum int) { + knownServers := make([]Server, pserverNum) + ticker := time.NewTicker(10 * time.Second) + monitor := func() { + curServers := make([]Server, pserverNum) + list := l.List() + for _, l := range list { + curServers[l.Index] = l + } + + for i := range knownServers { + if knownServers[i].Addr != curServers[i].Addr { + err := c.pservers[i].Connect(curServers[i].Addr) + if err != nil { + log.Println(err) + + // connect to addr failed, set + // to last known addr in order + // to retry next time. + curServers[i].Addr = knownServers[i].Addr + } + } + } + + knownServers = curServers + } + + monitor() + for _ = range ticker.C { + monitor() + } } // BeginInitParams begins to initialize parameters on parameter @@ -17,38 +87,146 @@ func NewClient(addr string) *Client { // servers. Other trainers will be blocked until the initialization is // done, and they need to get the initialized parameters from // parameter servers using GetParams. -func (c *Client) BeginInitParams(pserverConfigProto []byte) (selected bool, err error) { - return true, nil +func (c *Client) BeginInitParams() bool { + return c.sel.Select() } // InitParam initializes the parameter on parameter servers. func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { - return nil + var dummy int + return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy) } // FinishInitParams tells parameter servers client has sent all // parameters to parameter servers as initialization. func (c *Client) FinishInitParams() error { + for _, p := range c.pservers { + var dummy int + err := p.Call("Service.FinishInitParams", dummy, &dummy) + if err != nil { + return err + } + } return nil } // SendGrads sends gradients to parameter servers for updating // parameters. func (c *Client) SendGrads(grads []Gradient) error { + errCh := make(chan error, len(grads)) + for _, g := range grads { + go func(g Gradient) { + var dummy int + err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy) + errCh <- err + }(g) + } + + recv := 0 + for err := range errCh { + if err != nil { + return err + } + + recv++ + if recv == len(grads) { + break + } + } return nil } +type result struct { + idx int + param Parameter + err error +} + +type results []result + +func (r results) Len() int { + return len(r) +} + +func (r results) Less(i int, j int) bool { + return r[i].idx < r[j].idx +} + +func (r results) Swap(i int, j int) { + r[i], r[j] = r[j], r[i] +} + // GetParams gets parameters from parameter servers. func (c *Client) GetParams(names []string) ([]Parameter, error) { - return nil, nil + rCh := make(chan result, len(names)) + + for idx, name := range names { + go func(name string, idx int) { + var parameter Parameter + err := c.pservers[c.partition(name)].Call("Service.GetParam", name, ¶meter) + rCh <- result{idx: idx, param: parameter, err: err} + }(name, idx) + } + + var rs results + recv := 0 + for r := range rCh { + if r.err != nil { + return nil, r.err + } + rs = append(rs, r) + + recv++ + if recv == len(names) { + break + } + } + sort.Sort(rs) + + ps := make([]Parameter, len(rs)) + for i := range rs { + ps[i] = rs[i].param + } + + return ps, nil } -// SaveModel indicates parameters to save the parameter to the given -// path. -func (c *Client) SaveModel(path string) error { +// Save indicates parameters to save the parameter to the given path. +func (c *Client) Save(path string) error { + errCh := make(chan error, len(c.pservers)) + + for _, p := range c.pservers { + var dummy int + err := p.Call("Service.Save", path, &dummy) + errCh <- err + } + + recv := 0 + for err := range errCh { + if err != nil { + return err + } + + recv++ + if recv == len(c.pservers) { + break + } + } + + // TODO(helin): there will be many files under path, need to + // merge them into a single file. return nil } -// Cleanup cleans up the client states. -func (c *Client) Cleanup() { +func strHash(s string) uint32 { + h := fnv.New32a() + h.Write([]byte(s)) + return h.Sum32() +} + +// TODO(helin): now partition only select which parameter server to +// send the entire parameter. We need to partition a parameter into +// small blocks and send to different parameter servers. +func (c *Client) partition(key string) int { + return int(strHash(key) % uint32(len(c.pservers))) } diff --git a/go/pserver/client_test.go b/go/pserver/client_test.go new file mode 100644 index 0000000000000000000000000000000000000000..a9a0948a51a31a1c7393f716e3dfc436dbf919af --- /dev/null +++ b/go/pserver/client_test.go @@ -0,0 +1,123 @@ +package pserver_test + +import ( + "net" + "net/http" + "net/rpc" + "strconv" + "strings" + "testing" + + "github.com/PaddlePaddle/Paddle/go/pserver" +) + +const numPserver = 10 + +var port [numPserver]int + +func init() { + for i := 0; i < numPserver; i++ { + l, err := net.Listen("tcp", ":0") + if err != nil { + panic(err) + } + + ss := strings.Split(l.Addr().String(), ":") + p, err := strconv.Atoi(ss[len(ss)-1]) + if err != nil { + panic(err) + } + port[i] = p + + go func(l net.Listener) { + s := pserver.NewService() + server := rpc.NewServer() + err := server.Register(s) + if err != nil { + panic(err) + } + + mux := http.NewServeMux() + mux.Handle(rpc.DefaultRPCPath, server) + err = http.Serve(l, mux) + if err != nil { + panic(err) + } + }(l) + } +} + +type selector bool + +func (s selector) Select() bool { + return bool(s) +} + +type lister []pserver.Server + +func (l lister) List() []pserver.Server { + return l +} + +func TestClientFull(t *testing.T) { + servers := make([]pserver.Server, numPserver) + for i := 0; i < numPserver; i++ { + servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])} + } + c := pserver.NewClient(lister(servers), len(servers), selector(true)) + selected := c.BeginInitParams() + if !selected { + t.Fatal("should be selected.") + } + + const numParameter = 100 + for i := 0; i < numParameter; i++ { + var p pserver.Parameter + p.Name = "p_" + strconv.Itoa(i) + p.ElementType = pserver.Float32 + p.Content = make([]byte, (i+1)*100) + err := c.InitParam(pserver.ParameterWithConfig{Param: p}) + if err != nil { + t.Fatal(err) + } + } + + err := c.FinishInitParams() + if err != nil { + t.Fatal(err) + } + + var grads []pserver.Gradient + for i := 0; i < numParameter/2; i++ { + var g pserver.Gradient + g.Name = "p_" + strconv.Itoa(i) + g.ElementType = pserver.Float32 + g.Content = make([]byte, (i+1)*100) + grads = append(grads, g) + } + + err = c.SendGrads(grads) + if err != nil { + t.Fatal(err) + } + + names := make([]string, numParameter) + for i := 0; i < numParameter; i++ { + names[i] = "p_" + strconv.Itoa(i) + } + + params, err := c.GetParams(names) + if err != nil { + t.Fatal(err) + } + + if len(names) != len(params) { + t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params)) + } + + for i := range params { + if names[i] != params[i].Name { + t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i]) + } + } +} diff --git a/go/pserver/internal/connection/conn.go b/go/pserver/internal/connection/conn.go new file mode 100644 index 0000000000000000000000000000000000000000..1c04f117254054741b7d45fb16462b5ce84a2aea --- /dev/null +++ b/go/pserver/internal/connection/conn.go @@ -0,0 +1,84 @@ +package connection + +import ( + "errors" + "net/rpc" + "sync" +) + +// TODO(helin): add TCP re-connect logic + +// Conn is a connection to a parameter server +type Conn struct { + mu sync.Mutex + client *rpc.Client + waitConn chan struct{} +} + +// New creates a new connection. +func New() *Conn { + c := &Conn{} + return c +} + +// Connect connects the connection to a address. +func (c *Conn) Connect(addr string) error { + c.mu.Lock() + if c.client != nil { + err := c.client.Close() + if err != nil { + c.mu.Unlock() + return err + } + + c.client = nil + } + c.mu.Unlock() + + client, err := rpc.DialHTTP("tcp", addr) + if err != nil { + return err + } + + c.mu.Lock() + defer c.mu.Unlock() + + if c.client == nil { + c.client = client + if c.waitConn != nil { + close(c.waitConn) + c.waitConn = nil + } + } else { + return errors.New("client already set from a concurrent goroutine") + } + + return nil +} + +// Call make a RPC call. +// +// Call will be blocked until the connection to remote RPC service +// being established. +func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error { + c.mu.Lock() + client := c.client + var waitCh chan struct{} + if client == nil { + if c.waitConn != nil { + waitCh = c.waitConn + } else { + waitCh = make(chan struct{}) + c.waitConn = waitCh + } + } + c.mu.Unlock() + + if waitCh != nil { + // wait until new connection being established + <-waitCh + return c.Call(serviceMethod, args, reply) + } + + return client.Call(serviceMethod, args, reply) +} diff --git a/go/pserver/optimizer.go b/go/pserver/optimizer.go index 64bdefe660aaba7f53b5f3b6ee1cb9c0484baedb..417f8c509388055028bd46e42501741298308193 100644 --- a/go/pserver/optimizer.go +++ b/go/pserver/optimizer.go @@ -29,11 +29,11 @@ func newOptimizer(t optimizerType, learning_rate float64) *optimizer { func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { if len(p.Content) != len(g.Content) { - return fmt.Errorf("parameter and gradient length not match, parameter: %d, gradient: %d", len(p.Content), len(g.Content)) + return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, len(p.Content), len(g.Content)) } if p.ElementType != g.ElementType { - return fmt.Errorf("parameter and gradient element type not match, parameter: %v, gradient: %v", p.ElementType, g.ElementType) + return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType) } r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) diff --git a/go/pserver/service.go b/go/pserver/service.go index f43e59403a71cb5bed2187c2f2f80465642a5c65..d5787b9708bb15629a6e6290ffc97ee9885bc8b8 100644 --- a/go/pserver/service.go +++ b/go/pserver/service.go @@ -49,33 +49,12 @@ type Service struct { // NewService creates a new service. func NewService() *Service { - s := &Service{} + s := &Service{opt: newOptimizer(sgd, 0.01)} s.paramMap = make(map[string]Parameter) s.initialized = make(chan struct{}) return s } -// BeginInitParams tells the parameter server that the parameter -// initialization has begun. -func (s *Service) BeginInitParams(config []byte, dummy *int) error { - select { - case <-s.initialized: - return ErrAlreadyInitialized - default: - } - - s.mu.Lock() - defer s.mu.Unlock() - - if s.opt != nil { - s.opt.Cleanup() - } - - // TODO(helin): parse learning rate from config - s.opt = newOptimizer(sgd, 0.01) - return nil -} - // InitParam initializes a parameter. func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { select { @@ -109,75 +88,45 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { return nil } -// SendGrads sends gradients to parameter servers for parameter +// SendGrad sends gradient to parameter servers for parameter // optimization. -func (s *Service) SendGrads(grads []Gradient, dummy *int) error { +func (s *Service) SendGrad(g Gradient, dummy *int) error { select { case <-s.initialized: default: return ErrUninitialized } - count := len(grads) - if count == 0 { - return nil - } - s.mu.Lock() defer s.mu.Unlock() - for _, g := range grads { - if _, ok := s.paramMap[g.Name]; !ok { - return fmt.Errorf("parameter: %s does not exist", g.Name) - } - } - - errCh := make(chan error, count) - for _, g := range grads { - go func(p Parameter, g Gradient) { - err := s.opt.UpdateParameter(p, g) - errCh <- err - }(s.paramMap[g.Name], g) + p, ok := s.paramMap[g.Name] + if !ok { + return fmt.Errorf("parameter: %s does not exist", g.Name) } - recv := 0 - for err := range errCh { - if err != nil { - return err - } - - recv++ - if recv == count { - break - } - } - return nil + return s.opt.UpdateParameter(p, g) } -// GetParams gets parameters from the parameter server. -func (s *Service) GetParams(names []string, parameters *[]Parameter) error { +// GetParam gets parameters from the parameter server. +func (s *Service) GetParam(name string, parameter *Parameter) error { <-s.initialized s.mu.Lock() defer s.mu.Unlock() - for _, n := range names { - if _, ok := s.paramMap[n]; !ok { - return fmt.Errorf("parameter: %s does not exist", n) - } - } - - *parameters = make([]Parameter, len(names)) - for i, n := range names { - // The parameter content (a byte slice) may change - // during RPC serialization due to write from other - // goroutine, we allow it since mini-batch based deep - // learning optimization methods are stochastic in - // nature. This race condition is allowed deliberately - // to save the program from making a copy of the - // paramter content. - (*parameters)[i] = s.paramMap[n] + p, ok := s.paramMap[name] + if !ok { + return fmt.Errorf("parameter: %s does not exist", name) } + // The parameter content (a byte slice) may change + // during RPC serialization due to write from other + // goroutine, we allow it since mini-batch based deep + // learning optimization methods are stochastic in + // nature. This race condition is allowed deliberately + // to save the program from making a copy of the + // paramter content. + *parameter = p return nil } diff --git a/go/pserver/service_test.go b/go/pserver/service_test.go index c58ccf9231d5d2d0bbfaf1e583f7b1ecc79e11cd..4c9fac4536e09013916aadb26af3a86a5a775b4f 100644 --- a/go/pserver/service_test.go +++ b/go/pserver/service_test.go @@ -4,23 +4,19 @@ import ( "reflect" "sync" "testing" + "time" "github.com/PaddlePaddle/Paddle/go/pserver" ) func TestFull(t *testing.T) { s := pserver.NewService() - var dummy int - err := s.BeginInitParams(nil, &dummy) - if err != nil { - t.FailNow() - } - var p pserver.Parameter p.Name = "param_a" p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.ElementType = pserver.Int32 - err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy) + var dummy int + err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy) if err != nil { t.FailNow() } @@ -39,40 +35,39 @@ func TestFull(t *testing.T) { t.FailNow() } - var params []pserver.Parameter - err = s.GetParams([]string{"param_b", "param_a"}, ¶ms) + var param pserver.Parameter + err = s.GetParam("param_b", ¶m) if err != nil { t.FailNow() } - if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) { + if !reflect.DeepEqual(param, p1) { t.FailNow() } - grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)} - err = s.SendGrads(grads, &dummy) + g1, g2 := pserver.Gradient(p1), pserver.Gradient(p) + err = s.SendGrad(g1, &dummy) if err != nil { t.FailNow() } + err = s.SendGrad(g2, &dummy) - var params1 []pserver.Parameter - err = s.GetParams([]string{"param_b", "param_a"}, ¶ms1) if err != nil { t.FailNow() } - if len(params) != 2 { + var param1 pserver.Parameter + err = s.GetParam("param_a", ¶m1) + if err != nil { t.FailNow() } // don't compare content, since it's already changed by // gradient update. - params1[0].Content = nil - params1[0].Content = nil + param1.Content = nil p.Content = nil - p1.Content = nil - if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) { + if !reflect.DeepEqual(param1, p) { t.FailNow() } } @@ -80,19 +75,7 @@ func TestFull(t *testing.T) { func TestMultipleInit(t *testing.T) { s := pserver.NewService() var dummy int - err := s.BeginInitParams(nil, &dummy) - if err != nil { - t.FailNow() - } - - // this is fine, it's possible for client to call init - // multiple times. - err = s.BeginInitParams(nil, &dummy) - if err != nil { - t.FailNow() - } - - err = s.FinishInitParams(0, &dummy) + err := s.FinishInitParams(0, &dummy) if err != nil { t.FailNow() } @@ -101,17 +84,12 @@ func TestMultipleInit(t *testing.T) { if err != pserver.ErrAlreadyInitialized { t.FailNow() } - - err = s.BeginInitParams(nil, &dummy) - if err != pserver.ErrAlreadyInitialized { - t.FailNow() - } } func TestUninitialized(t *testing.T) { s := pserver.NewService() var dummy int - err := s.SendGrads(nil, &dummy) + err := s.SendGrad(pserver.Gradient{}, &dummy) if err != pserver.ErrUninitialized { t.FailNow() } @@ -123,8 +101,8 @@ func TestBlockUntilInitialized(t *testing.T) { var wg sync.WaitGroup wg.Add(1) go func() { - var params []pserver.Parameter - err := s.GetParams(nil, ¶ms) + var param pserver.Parameter + err := s.GetParam("param_a", ¶m) if err != nil { t.FailNow() } @@ -143,11 +121,7 @@ func TestBlockUntilInitialized(t *testing.T) { ch <- struct{}{} }() - var dummy int - err := s.BeginInitParams(nil, &dummy) - if err != nil { - t.FailNow() - } + time.Sleep(50 * time.Millisecond) select { case <-ch: @@ -156,6 +130,16 @@ func TestBlockUntilInitialized(t *testing.T) { default: } + var p pserver.Parameter + p.Name = "param_a" + p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} + p.ElementType = pserver.Int32 + var dummy int + err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy) + if err != nil { + t.FailNow() + } + err = s.FinishInitParams(0, &dummy) if err != nil { t.FailNow()