// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package client

import (
	"errors"
	"hash/fnv"
	"sort"
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
	"github.com/PaddlePaddle/Paddle/go/pserver"
	log "github.com/sirupsen/logrus"
)

// TODO(helin): add RPC call retry logic

// Selector decides whether this client is the one responsible for
// initializing the parameters on the parameter servers.
type Selector interface {
	// Select reports true when the caller should perform the
	// parameter initialization.
	Select() bool
}

// Server identifies one parameter server by its slot index and its
// network address.
type Server struct {
	Index int    // slot in the client's connection list
	Addr  string // network address; empty when the slot has no server
}

// Lister reports the parameter servers that are currently available.
type Lister interface {
	// List returns the known servers; each entry's Index names the
	// slot it occupies.
	List() []Server
}

H
Helin Wang 已提交
46
// Client is the client to parameter servers.
47
type Client struct {
48 49
	sel      Selector
	pservers []*connection.Conn
50 51
}

H
Helin Wang 已提交
52
// NewClient creates a new client.
53 54 55 56 57 58 59 60 61 62 63 64 65
func NewClient(l Lister, pserverNum int, sel Selector) *Client {
	c := &Client{sel: sel}
	c.pservers = make([]*connection.Conn, pserverNum)
	for i := 0; i < pserverNum; i++ {
		c.pservers[i] = connection.New()
	}
	go c.monitorPservers(l, pserverNum)
	return c
}

// monitorPservers monitors pserver addresses, and updates connection
// when the address changes.
func (c *Client) monitorPservers(l Lister, pserverNum int) {
66
	lastServers := make([]Server, pserverNum)
67 68 69 70 71 72 73 74
	ticker := time.NewTicker(10 * time.Second)
	monitor := func() {
		curServers := make([]Server, pserverNum)
		list := l.List()
		for _, l := range list {
			curServers[l.Index] = l
		}

75
		for i := range lastServers {
H
Helin Wang 已提交
76 77 78
			if lastServers[i].Addr == curServers[i].Addr {
				continue
			}
79

H
Helin Wang 已提交
80 81
			if curServers[i].Addr == "" {
				err := c.pservers[i].Close()
82
				if err != nil {
H
Helin Wang 已提交
83
					log.Errorln(err)
84
				}
H
Helin Wang 已提交
85 86

				continue
87
			}
H
Helin Wang 已提交
88 89 90

			err := c.pservers[i].Connect(curServers[i].Addr)
			if err != nil {
H
Helin Wang 已提交
91
				log.Errorln(err)
H
Helin Wang 已提交
92 93 94 95 96 97 98

				// connect to addr failed, set
				// to last known addr in order
				// to retry next time.
				curServers[i].Addr = lastServers[i].Addr
			}

99 100
		}

101
		lastServers = curServers
102 103 104
	}

	monitor()
H
Helin Wang 已提交
105
	for range ticker.C {
106 107
		monitor()
	}
108 109
}

H
Helin Wang 已提交
110 111 112 113 114 115 116 117
// BeginInitParams begins to initialize parameters on parameter
// servers.
//
// BeginInitParams will be called from multiple trainers, only one
// trainer will be selected to initialize the parameters on parameter
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
118 119
func (c *Client) BeginInitParams() bool {
	return c.sel.Select()
120 121
}

H
Helin Wang 已提交
122
// InitParam initializes the parameter on parameter servers.
Q
Qiao Longfei 已提交
123
func (c *Client) InitParam(paramWithConfigs pserver.ParameterWithConfig) error {
124
	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
125 126
}

H
Helin Wang 已提交
127 128
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
129
func (c *Client) FinishInitParams() error {
130
	for _, p := range c.pservers {
131
		err := p.Call("Service.FinishInitParams", 0, nil)
132 133 134 135
		if err != nil {
			return err
		}
	}
136 137 138
	return nil
}

H
Helin Wang 已提交
139 140
// SendGrads sends gradients to parameter servers for updating
// parameters.
Q
Qiao Longfei 已提交
141
func (c *Client) SendGrads(grads []pserver.Gradient) error {
142
	if len(grads) == 0 {
D
dongzhihong 已提交
143
		return errors.New("no gradient received")
144
	}
145 146
	errCh := make(chan error, len(grads))
	for _, g := range grads {
Q
Qiao Longfei 已提交
147
		go func(g pserver.Gradient) {
148
			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
149 150 151 152 153 154 155 156 157 158 159 160 161 162 163
			errCh <- err
		}(g)
	}

	recv := 0
	for err := range errCh {
		if err != nil {
			return err
		}

		recv++
		if recv == len(grads) {
			break
		}
	}
164 165 166
	return nil
}

167
type result struct {
H
Helin Wang 已提交
168
	idx   int
Q
Qiao Longfei 已提交
169
	param pserver.Parameter
H
Helin Wang 已提交
170
	err   error
171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186
}

// results implements sort.Interface, ordering entries by idx so the
// fetched parameters can be returned in request order.
type results []result

func (r results) Len() int { return len(r) }

func (r results) Less(i, j int) bool { return r[i].idx < r[j].idx }

func (r results) Swap(i, j int) { r[j], r[i] = r[i], r[j] }

H
Helin Wang 已提交
187
// GetParams gets parameters from parameter servers.
Q
Qiao Longfei 已提交
188
func (c *Client) GetParams(names []string) ([]pserver.Parameter, error) {
189 190 191 192
	rCh := make(chan result, len(names))

	for idx, name := range names {
		go func(name string, idx int) {
Q
Qiao Longfei 已提交
193
			var parameter pserver.Parameter
194
			err := c.pservers[c.partition(name)].Call("Service.GetParam", name, &parameter)
H
Helin Wang 已提交
195
			rCh <- result{idx: idx, param: parameter, err: err}
196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213
		}(name, idx)
	}

	var rs results
	recv := 0
	for r := range rCh {
		if r.err != nil {
			return nil, r.err
		}
		rs = append(rs, r)

		recv++
		if recv == len(names) {
			break
		}
	}
	sort.Sort(rs)

Q
Qiao Longfei 已提交
214
	ps := make([]pserver.Parameter, len(rs))
215
	for i := range rs {
H
Helin Wang 已提交
216
		ps[i] = rs[i].param
217 218 219
	}

	return ps, nil
220 221
}

222 223 224 225 226
// Save indicates parameters to save the parameter to the given path.
func (c *Client) Save(path string) error {
	errCh := make(chan error, len(c.pservers))

	for _, p := range c.pservers {
227
		err := p.Call("Service.Save", path, nil)
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
		errCh <- err
	}

	recv := 0
	for err := range errCh {
		if err != nil {
			return err
		}

		recv++
		if recv == len(c.pservers) {
			break
		}
	}

	// TODO(helin): there will be many files under path, need to
	// merge them into a single file.
245 246 247
	return nil
}

// strHash returns the 32-bit FNV-1a hash of s.
func strHash(s string) uint32 {
	h := fnv.New32a()
	// hash.Hash's Write is documented to never return an error, so the
	// result is deliberately discarded.
	_, _ = h.Write([]byte(s))
	return h.Sum32()
}

// partition returns the index of the parameter server responsible for
// key: the FNV-1a hash of key modulo the number of parameter servers.
// It panics (integer divide by zero) if the client has no parameter
// servers.
//
// TODO(helin): partition currently only selects which parameter server
// receives the entire parameter. We need to partition a parameter into
// small blocks and send them to different parameter servers.
func (c *Client) partition(key string) int {
	return int(strHash(key) % uint32(len(c.pservers)))
}