client.go 5.0 KB
Newer Older
1 2
package pserver

3 4 5 6 7 8
import (
	"hash/fnv"
	"log"
	"sort"
	"time"

H
Helin Wang 已提交
9
	"github.com/PaddlePaddle/Paddle/go/pserver/internal/connection"
10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29
)

// TODO(helin): add RPC call retry logic

type (
	// Selector decides whether this client is the one that initializes
	// the parameter servers.
	Selector interface {
		// Select reports whether this client should initialize the
		// parameter servers.
		Select() bool
	}

	// Server identifies one parameter server: its position among the
	// parameter servers (Index) and the address used to connect to it
	// (Addr).
	Server struct {
		Index int
		Addr  string
	}

	// Lister reports the parameter servers that are currently
	// available.
	Lister interface {
		List() []Server
	}
)

H
Helin Wang 已提交
30
// Client is the client to parameter servers.
31
type Client struct {
32 33
	sel      Selector
	pservers []*connection.Conn
34 35
}

H
Helin Wang 已提交
36
// NewClient creates a new client.
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
func NewClient(l Lister, pserverNum int, sel Selector) *Client {
	c := &Client{sel: sel}
	c.pservers = make([]*connection.Conn, pserverNum)
	for i := 0; i < pserverNum; i++ {
		c.pservers[i] = connection.New()
	}
	go c.monitorPservers(l, pserverNum)
	return c
}

// monitorPservers monitors pserver addresses, and updates connection
// when the address changes.
func (c *Client) monitorPservers(l Lister, pserverNum int) {
	knownServers := make([]Server, pserverNum)
	ticker := time.NewTicker(10 * time.Second)
	monitor := func() {
		curServers := make([]Server, pserverNum)
		list := l.List()
		for _, l := range list {
			curServers[l.Index] = l
		}

		for i := range knownServers {
			if knownServers[i].Addr != curServers[i].Addr {
				err := c.pservers[i].Connect(curServers[i].Addr)
				if err != nil {
					log.Println(err)

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curServers[i].Addr = knownServers[i].Addr
				}
			}
		}

		knownServers = curServers
	}

	monitor()
	for _ = range ticker.C {
		monitor()
	}
80 81
}

H
Helin Wang 已提交
82 83 84 85 86 87 88 89
// BeginInitParams begins to initialize parameters on parameter
// servers.
//
// BeginInitParams will be called from multiple trainers, only one
// trainer will be selected to initialize the parameters on parameter
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
90 91
func (c *Client) BeginInitParams() bool {
	return c.sel.Select()
92 93
}

H
Helin Wang 已提交
94
// InitParam initializes the parameter on parameter servers.
95
func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
96 97
	var dummy int
	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
98 99
}

H
Helin Wang 已提交
100 101
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
102
func (c *Client) FinishInitParams() error {
103 104 105 106 107 108 109
	for _, p := range c.pservers {
		var dummy int
		err := p.Call("Service.FinishInitParams", dummy, &dummy)
		if err != nil {
			return err
		}
	}
110 111 112
	return nil
}

H
Helin Wang 已提交
113 114
// SendGrads sends gradients to parameter servers for updating
// parameters.
115
func (c *Client) SendGrads(grads []Gradient) error {
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135
	errCh := make(chan error, len(grads))
	for _, g := range grads {
		go func(g Gradient) {
			var dummy int
			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
			errCh <- err
		}(g)
	}

	recv := 0
	for err := range errCh {
		if err != nil {
			return err
		}

		recv++
		if recv == len(grads) {
			break
		}
	}
136 137 138
	return nil
}

139
type result struct {
H
Helin Wang 已提交
140 141 142
	idx   int
	param Parameter
	err   error
143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158
}

type results []result

func (r results) Len() int {
	return len(r)
}

func (r results) Less(i int, j int) bool {
	return r[i].idx < r[j].idx
}

func (r results) Swap(i int, j int) {
	r[i], r[j] = r[j], r[i]
}

H
Helin Wang 已提交
159
// GetParams gets parameters from parameter servers.
160
func (c *Client) GetParams(names []string) ([]Parameter, error) {
161 162 163 164 165 166
	rCh := make(chan result, len(names))

	for idx, name := range names {
		go func(name string, idx int) {
			var parameter Parameter
			err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
H
Helin Wang 已提交
167
			rCh <- result{idx: idx, param: parameter, err: err}
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187
		}(name, idx)
	}

	var rs results
	recv := 0
	for r := range rCh {
		if r.err != nil {
			return nil, r.err
		}
		rs = append(rs, r)

		recv++
		if recv == len(names) {
			break
		}
	}
	sort.Sort(rs)

	ps := make([]Parameter, len(rs))
	for i := range rs {
H
Helin Wang 已提交
188
		ps[i] = rs[i].param
189 190 191
	}

	return ps, nil
192 193
}

194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
	errCh := make(chan error, len(c.pservers))

	for _, p := range c.pservers {
		var dummy int
		err := p.Call("Service.Save", path, &dummy)
		errCh <- err
	}

	recv := 0
	for err := range errCh {
		if err != nil {
			return err
		}

		recv++
		if recv == len(c.pservers) {
			break
		}
	}

	// TODO(helin): there will be many files under path, need to
	// merge them into a single file.
218 219 220
	return nil
}

// strHash returns the 32-bit FNV-1a hash of s.
func strHash(s string) uint32 {
	h := fnv.New32a()
	// hash.Hash's Write never returns an error.
	_, _ = h.Write([]byte(s))
	return h.Sum32()
}

// partition returns the index of the parameter server responsible for
// key: strHash(key) modulo the number of parameter servers. It panics
// (integer divide by zero) if the client has no parameter servers.
//
// TODO(helin): partition currently only selects which parameter server
// receives the entire parameter. We need to partition a parameter into
// small blocks and send them to different parameter servers.
func (c *Client) partition(key string) int {
	return int(strHash(key) % uint32(len(c.pservers)))
}