提交 0babf84b 编写于 作者: H Helin Wang

implement pserver RPC part, and simple parameter partition.

上级 7e93921a
...@@ -136,6 +136,9 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad ...@@ -136,6 +136,9 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad
/** /**
* @brief paddle_get_params gets parameters from parameter servers. * @brief paddle_get_params gets parameters from parameter servers.
* *
* paddle_get_params will block until parameters are initialized on
* the parameter servers.
*
* @param names the array of names of the parameters to get. * @param names the array of names of the parameters to get.
* @param dst the destination array of parameters to save to. * @param dst the destination array of parameters to save to.
* @param len the length of the names array and the paddle_parameter * @param len the length of the names array and the paddle_parameter
......
...@@ -39,6 +39,7 @@ import "C" ...@@ -39,6 +39,7 @@ import "C"
import ( import (
"log" "log"
"strings"
"sync" "sync"
"unsafe" "unsafe"
...@@ -86,29 +87,46 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte { ...@@ -86,29 +87,46 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return (*[1 << 30]byte)(p)[:len:len] return (*[1 << 30]byte)(p)[:len:len]
} }
type selector bool
func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
func (l lister) List() []pserver.Server {
return l
}
//export paddle_new_pserver_client //export paddle_new_pserver_client
func paddle_new_pserver_client(addr *C.char) C.client { func paddle_new_pserver_client(addrs *C.char, selected bool) C.client {
c := pserver.NewClient(C.GoString(addr)) a := C.GoString(addrs)
as := strings.Split(a, ",")
servers := make([]pserver.Server, len(as))
for i := range as {
servers[i].Index = i
servers[i].Addr = as[i]
}
c := pserver.NewClient(lister(servers), len(as), selector(selected))
return add(c) return add(c)
} }
//export paddle_new_etcd_pserver_client
func paddle_new_etcd_pserver_client(etcd_addr *C.char) C.client {
// TODO(helin): fault tolerant pserver client using etcd.
panic("not implemented.")
}
//export paddle_pserver_client_release //export paddle_pserver_client_release
func paddle_pserver_client_release(client C.client) { func paddle_pserver_client_release(client C.client) {
c := remove(client) remove(client)
c.Cleanup()
} }
//export paddle_begin_init_params //export paddle_begin_init_params
func paddle_begin_init_params(client C.client, pserver_config unsafe.Pointer, config_len C.int) C.int { func paddle_begin_init_params(client C.client) C.int {
c := get(client) c := get(client)
b := cArrayToSlice(pserver_config, int(config_len)) if selected := c.BeginInitParams(); selected {
selected, err := c.BeginInitParams(b)
if err != nil {
log.Println(err)
return -1
}
if selected {
return 1 return 1
} }
return 0 return 0
...@@ -230,7 +248,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter ...@@ -230,7 +248,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
func paddle_save_model(client C.client, path *C.char) C.int { func paddle_save_model(client C.client, path *C.char) C.int {
p := C.GoString(path) p := C.GoString(path)
c := get(client) c := get(client)
err := c.SaveModel(p) err := c.Save(p)
if err != nil { if err != nil {
log.Println(err) log.Println(err)
return -1 return -1
......
package pserver package pserver
import (
"hash/fnv"
"log"
"sort"
"time"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver/internal/connection"
)
// TODO(helin): add RPC call retry logic
// Selector selects if the client should initialize parameter servers.
type Selector interface {
Select() bool
}
// Server is the identification of a parameter Server.
type Server struct {
Index int
Addr string
}
// Lister lists currently available parameter servers.
type Lister interface {
List() []Server
}
// Client is the client to parameter servers. // Client is the client to parameter servers.
type Client struct { type Client struct {
sel Selector
pservers []*connection.Conn
} }
// NewClient creates a new client. // NewClient creates a new client.
func NewClient(addr string) *Client { func NewClient(l Lister, pserverNum int, sel Selector) *Client {
return &Client{} c := &Client{sel: sel}
c.pservers = make([]*connection.Conn, pserverNum)
for i := 0; i < pserverNum; i++ {
c.pservers[i] = connection.New()
}
go c.monitorPservers(l, pserverNum)
return c
}
// monitorPservers monitors pserver addresses, and updates connection
// when the address changes.
func (c *Client) monitorPservers(l Lister, pserverNum int) {
knownServers := make([]Server, pserverNum)
ticker := time.NewTicker(10 * time.Second)
monitor := func() {
curServers := make([]Server, pserverNum)
list := l.List()
for _, l := range list {
curServers[l.Index] = l
}
for i := range knownServers {
if knownServers[i].Addr != curServers[i].Addr {
err := c.pservers[i].Connect(curServers[i].Addr)
if err != nil {
log.Println(err)
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curServers[i].Addr = knownServers[i].Addr
}
}
}
knownServers = curServers
}
monitor()
for _ = range ticker.C {
monitor()
}
} }
// BeginInitParams begins to initialize parameters on parameter // BeginInitParams begins to initialize parameters on parameter
...@@ -17,38 +87,146 @@ func NewClient(addr string) *Client { ...@@ -17,38 +87,146 @@ func NewClient(addr string) *Client {
// servers. Other trainers will be blocked until the initialization is // servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from // done, and they need to get the initialized parameters from
// parameter servers using GetParams. // parameter servers using GetParams.
func (c *Client) BeginInitParams(pserverConfigProto []byte) (selected bool, err error) { func (c *Client) BeginInitParams() bool {
return true, nil return c.sel.Select()
} }
// InitParam initializes the parameter on parameter servers. // InitParam initializes the parameter on parameter servers.
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error { func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
return nil var dummy int
return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
} }
// FinishInitParams tells parameter servers client has sent all // FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization. // parameters to parameter servers as initialization.
func (c *Client) FinishInitParams() error { func (c *Client) FinishInitParams() error {
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.FinishInitParams", dummy, &dummy)
if err != nil {
return err
}
}
return nil return nil
} }
// SendGrads sends gradients to parameter servers for updating // SendGrads sends gradients to parameter servers for updating
// parameters. // parameters.
func (c *Client) SendGrads(grads []Gradient) error { func (c *Client) SendGrads(grads []Gradient) error {
errCh := make(chan error, len(grads))
for _, g := range grads {
go func(g Gradient) {
var dummy int
err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
errCh <- err
}(g)
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(grads) {
break
}
}
return nil return nil
} }
type result struct {
idx int
p Parameter
err error
}
type results []result
func (r results) Len() int {
return len(r)
}
func (r results) Less(i int, j int) bool {
return r[i].idx < r[j].idx
}
func (r results) Swap(i int, j int) {
r[i], r[j] = r[j], r[i]
}
// GetParams gets parameters from parameter servers. // GetParams gets parameters from parameter servers.
func (c *Client) GetParams(names []string) ([]Parameter, error) { func (c *Client) GetParams(names []string) ([]Parameter, error) {
return nil, nil rCh := make(chan result, len(names))
for idx, name := range names {
go func(name string, idx int) {
var parameter Parameter
err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
rCh <- result{idx: idx, p: parameter, err: err}
}(name, idx)
}
var rs results
recv := 0
for r := range rCh {
if r.err != nil {
return nil, r.err
}
rs = append(rs, r)
recv++
if recv == len(names) {
break
}
}
sort.Sort(rs)
ps := make([]Parameter, len(rs))
for i := range rs {
ps[i] = rs[i].p
}
return ps, nil
} }
// SaveModel indicates parameters to save the parameter to the given // Save indicates parameters to save the parameter to the given path.
// path. func (c *Client) Save(path string) error {
func (c *Client) SaveModel(path string) error { errCh := make(chan error, len(c.pservers))
for _, p := range c.pservers {
var dummy int
err := p.Call("Service.Save", path, &dummy)
errCh <- err
}
recv := 0
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == len(c.pservers) {
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return nil return nil
} }
// Cleanup cleans up the client states. func strHash(s string) uint32 {
func (c *Client) Cleanup() { h := fnv.New32a()
h.Write([]byte(s))
return h.Sum32()
}
// TODO(helin): now partition only select which parameter server to
// send the entire parameter. We need to partition a parameter into
// small blocks and send to different parameter servers.
func (c *Client) partition(key string) int {
return int(strHash(key) % uint32(len(c.pservers)))
} }
...@@ -29,11 +29,11 @@ func newOptimizer(t optimizerType, learning_rate float64) *optimizer { ...@@ -29,11 +29,11 @@ func newOptimizer(t optimizerType, learning_rate float64) *optimizer {
func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error { func (o *optimizer) UpdateParameter(p Parameter, g Gradient) error {
if len(p.Content) != len(g.Content) { if len(p.Content) != len(g.Content) {
return fmt.Errorf("parameter and gradient length not match, parameter: %d, gradient: %d", len(p.Content), len(g.Content)) return fmt.Errorf("Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d", p.Name, len(p.Content), len(g.Content))
} }
if p.ElementType != g.ElementType { if p.ElementType != g.ElementType {
return fmt.Errorf("parameter and gradient element type not match, parameter: %v, gradient: %v", p.ElementType, g.ElementType) return fmt.Errorf("Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v", p.Name, p.ElementType, g.ElementType)
} }
r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content))) r := C.paddle_update_parameter(o.opt, unsafe.Pointer(&p.Content[0]), C.paddle_element_type(p.ElementType), unsafe.Pointer(&g.Content[0]), C.int(len(g.Content)))
......
...@@ -49,33 +49,12 @@ type Service struct { ...@@ -49,33 +49,12 @@ type Service struct {
// NewService creates a new service. // NewService creates a new service.
func NewService() *Service { func NewService() *Service {
s := &Service{} s := &Service{opt: newOptimizer(sgd, 0.01)}
s.paramMap = make(map[string]Parameter) s.paramMap = make(map[string]Parameter)
s.initialized = make(chan struct{}) s.initialized = make(chan struct{})
return s return s
} }
// BeginInitParams tells the parameter server that the parameter
// initialization has begun.
func (s *Service) BeginInitParams(config []byte, dummy *int) error {
select {
case <-s.initialized:
return ErrAlreadyInitialized
default:
}
s.mu.Lock()
defer s.mu.Unlock()
if s.opt != nil {
s.opt.Cleanup()
}
// TODO(helin): parse learning rate from config
s.opt = newOptimizer(sgd, 0.01)
return nil
}
// InitParam initializes a parameter. // InitParam initializes a parameter.
func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error { func (s *Service) InitParam(paramWithConfigs ParameterWithConfig, dummy *int) error {
select { select {
...@@ -109,75 +88,45 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { ...@@ -109,75 +88,45 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
return nil return nil
} }
// SendGrads sends gradients to parameter servers for parameter // SendGrad sends gradient to parameter servers for parameter
// optimization. // optimization.
func (s *Service) SendGrads(grads []Gradient, dummy *int) error { func (s *Service) SendGrad(g Gradient, dummy *int) error {
select { select {
case <-s.initialized: case <-s.initialized:
default: default:
return ErrUninitialized return ErrUninitialized
} }
count := len(grads)
if count == 0 {
return nil
}
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for _, g := range grads { p, ok := s.paramMap[g.Name]
if _, ok := s.paramMap[g.Name]; !ok { if !ok {
return fmt.Errorf("parameter: %s does not exist", g.Name) return fmt.Errorf("parameter: %s does not exist", g.Name)
}
}
errCh := make(chan error, count)
for _, g := range grads {
go func(p Parameter, g Gradient) {
err := s.opt.UpdateParameter(p, g)
errCh <- err
}(s.paramMap[g.Name], g)
} }
recv := 0 return s.opt.UpdateParameter(p, g)
for err := range errCh {
if err != nil {
return err
}
recv++
if recv == count {
break
}
}
return nil
} }
// GetParams gets parameters from the parameter server. // GetParam gets parameters from the parameter server.
func (s *Service) GetParams(names []string, parameters *[]Parameter) error { func (s *Service) GetParam(name string, parameter *Parameter) error {
<-s.initialized <-s.initialized
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for _, n := range names { p, ok := s.paramMap[name]
if _, ok := s.paramMap[n]; !ok { if !ok {
return fmt.Errorf("parameter: %s does not exist", n) return fmt.Errorf("parameter: %s does not exist", name)
}
}
*parameters = make([]Parameter, len(names))
for i, n := range names {
// The parameter content (a byte slice) may change
// during RPC serialization due to write from other
// goroutine, we allow it since mini-batch based deep
// learning optimization methods are stochastic in
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// paramter content.
(*parameters)[i] = s.paramMap[n]
} }
// The parameter content (a byte slice) may change
// during RPC serialization due to write from other
// goroutine, we allow it since mini-batch based deep
// learning optimization methods are stochastic in
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// paramter content.
*parameter = p
return nil return nil
} }
......
...@@ -4,23 +4,19 @@ import ( ...@@ -4,23 +4,19 @@ import (
"reflect" "reflect"
"sync" "sync"
"testing" "testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver" "github.com/PaddlePaddle/Paddle/go/pserver"
) )
func TestFull(t *testing.T) { func TestFull(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var dummy int
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
var p pserver.Parameter var p pserver.Parameter
p.Name = "param_a" p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0} p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32 p.ElementType = pserver.Int32
err = s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy) var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -39,40 +35,39 @@ func TestFull(t *testing.T) { ...@@ -39,40 +35,39 @@ func TestFull(t *testing.T) {
t.FailNow() t.FailNow()
} }
var params []pserver.Parameter var param pserver.Parameter
err = s.GetParams([]string{"param_b", "param_a"}, &params) err = s.GetParam("param_b", &param)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
if len(params) != 2 || !reflect.DeepEqual(params[0], p1) || !reflect.DeepEqual(params[0], p1) { if !reflect.DeepEqual(param, p1) {
t.FailNow() t.FailNow()
} }
grads := []pserver.Gradient{pserver.Gradient(p1), pserver.Gradient(p)} g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
err = s.SendGrads(grads, &dummy) err = s.SendGrad(g1, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
err = s.SendGrad(g2, &dummy)
var params1 []pserver.Parameter
err = s.GetParams([]string{"param_b", "param_a"}, &params1)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
if len(params) != 2 { var param1 pserver.Parameter
err = s.GetParam("param_a", &param1)
if err != nil {
t.FailNow() t.FailNow()
} }
// don't compare content, since it's already changed by // don't compare content, since it's already changed by
// gradient update. // gradient update.
params1[0].Content = nil param1.Content = nil
params1[0].Content = nil
p.Content = nil p.Content = nil
p1.Content = nil
if !reflect.DeepEqual(params1[0], p1) || !reflect.DeepEqual(params1[0], p1) { if !reflect.DeepEqual(param1, p) {
t.FailNow() t.FailNow()
} }
} }
...@@ -80,19 +75,7 @@ func TestFull(t *testing.T) { ...@@ -80,19 +75,7 @@ func TestFull(t *testing.T) {
func TestMultipleInit(t *testing.T) { func TestMultipleInit(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var dummy int var dummy int
err := s.BeginInitParams(nil, &dummy) err := s.FinishInitParams(0, &dummy)
if err != nil {
t.FailNow()
}
// this is fine, it's possible for client to call init
// multiple times.
err = s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -101,17 +84,12 @@ func TestMultipleInit(t *testing.T) { ...@@ -101,17 +84,12 @@ func TestMultipleInit(t *testing.T) {
if err != pserver.ErrAlreadyInitialized { if err != pserver.ErrAlreadyInitialized {
t.FailNow() t.FailNow()
} }
err = s.BeginInitParams(nil, &dummy)
if err != pserver.ErrAlreadyInitialized {
t.FailNow()
}
} }
func TestUninitialized(t *testing.T) { func TestUninitialized(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
var dummy int var dummy int
err := s.SendGrads(nil, &dummy) err := s.SendGrad(pserver.Gradient{}, &dummy)
if err != pserver.ErrUninitialized { if err != pserver.ErrUninitialized {
t.FailNow() t.FailNow()
} }
...@@ -123,8 +101,8 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -123,8 +101,8 @@ func TestBlockUntilInitialized(t *testing.T) {
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
var params []pserver.Parameter var param pserver.Parameter
err := s.GetParams(nil, &params) err := s.GetParam("param_a", &param)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
...@@ -143,11 +121,7 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -143,11 +121,7 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{} ch <- struct{}{}
}() }()
var dummy int time.Sleep(50 * time.Millisecond)
err := s.BeginInitParams(nil, &dummy)
if err != nil {
t.FailNow()
}
select { select {
case <-ch: case <-ch:
...@@ -156,6 +130,16 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -156,6 +130,16 @@ func TestBlockUntilInitialized(t *testing.T) {
default: default:
} }
var p pserver.Parameter
p.Name = "param_a"
p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
p.ElementType = pserver.Int32
var dummy int
err := s.InitParam(pserver.ParameterWithConfig{p, nil}, &dummy)
if err != nil {
t.FailNow()
}
err = s.FinishInitParams(0, &dummy) err = s.FinishInitParams(0, &dummy)
if err != nil { if err != nil {
t.FailNow() t.FailNow()
......
package pserver_test
import (
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"testing"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)
const numPserver = 10
var port [numPserver]int
func init() {
for i := 0; i < numPserver; i++ {
l, err := net.Listen("tcp", ":0")
if err != nil {
panic(err)
}
ss := strings.Split(l.Addr().String(), ":")
p, err := strconv.Atoi(ss[len(ss)-1])
if err != nil {
panic(err)
}
port[i] = p
go func(l net.Listener) {
s := pserver.NewService()
server := rpc.NewServer()
err := server.Register(s)
if err != nil {
panic(err)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
}
}(l)
}
}
type selector bool
func (s selector) Select() bool {
return bool(s)
}
type lister []pserver.Server
func (l lister) List() []pserver.Server {
return l
}
func TestClientFull(t *testing.T) {
servers := make([]pserver.Server, numPserver)
for i := 0; i < numPserver; i++ {
servers[i] = pserver.Server{Index: i, Addr: ":" + strconv.Itoa(port[i])}
}
c := pserver.NewClient(lister(servers), len(servers), selector(true))
selected := c.BeginInitParams()
if !selected {
t.Fatal("should be selected.")
}
const numParameter = 100
for i := 0; i < numParameter; i++ {
var p pserver.Parameter
p.Name = "p_" + strconv.Itoa(i)
p.ElementType = pserver.Float32
p.Content = make([]byte, (i+1)*100)
err := c.InitParam(pserver.ParameterWithConfig{Param: p})
if err != nil {
t.Fatal(err)
}
}
err := c.FinishInitParams()
if err != nil {
t.Fatal(err)
}
var grads []pserver.Gradient
for i := 0; i < numParameter/2; i++ {
var g pserver.Gradient
g.Name = "p_" + strconv.Itoa(i)
g.ElementType = pserver.Float32
g.Content = make([]byte, (i+1)*100)
grads = append(grads, g)
}
err = c.SendGrads(grads)
if err != nil {
t.Fatal(err)
}
names := make([]string, numParameter)
for i := 0; i < numParameter; i++ {
names[i] = "p_" + strconv.Itoa(i)
}
params, err := c.GetParams(names)
if err != nil {
t.Fatal(err)
}
if len(names) != len(params) {
t.Fatalf("parameter size not match, need: %d, have: %d", len(names), len(params))
}
for i := range params {
if names[i] != params[i].Name {
t.Fatalf("order of returned parameter does not required: parameter name: %s, required name: %s", names[i], params[i])
}
}
}
package connection
import (
"errors"
"net/rpc"
"sync"
)
// TODO(helin): add TCP re-connect logic
// Conn is a connection to a parameter server
type Conn struct {
mu sync.Mutex
client *rpc.Client
waitConn chan struct{}
}
// New creates a new connection.
func New() *Conn {
c := &Conn{}
return c
}
// Connect connects the connection to a address.
func (c *Conn) Connect(addr string) error {
c.mu.Lock()
if c.client != nil {
err := c.client.Close()
if err != nil {
c.mu.Unlock()
return err
}
c.client = nil
}
c.mu.Unlock()
client, err := rpc.DialHTTP("tcp", addr)
if err != nil {
return err
}
c.mu.Lock()
defer c.mu.Unlock()
if c.client == nil {
c.client = client
if c.waitConn != nil {
close(c.waitConn)
c.waitConn = nil
}
} else {
return errors.New("client already set from a concurrent goroutine")
}
return nil
}
// Call make a RPC call.
//
// Call will be blocked until the connection to remote RPC service
// being established.
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
c.mu.Lock()
client := c.client
var waitCh chan struct{}
if client == nil {
if c.waitConn != nil {
waitCh = c.waitConn
} else {
waitCh = make(chan struct{})
c.waitConn = waitCh
}
}
c.mu.Unlock()
if waitCh != nil {
// wait until new connection being established
<-waitCh
return c.Call(serviceMethod, args, reply)
}
return client.Call(serviceMethod, args, reply)
}
package pserver
type partitioner struct {
shardNum int
}
// partitioner partitions the parameters into shards.
func newPartitioner(shardNum int) *partitioner {
return &partitioner{shardNum: shardNum}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册