client.go 5.7 KB
Newer Older
D
dongzhihong 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15 16 17
package master

import (
H
Helin Wang 已提交
18
	"os"
19
	"time"
20 21

	"github.com/PaddlePaddle/Paddle/go/connection"
H
Helin Wang 已提交
22
	"github.com/PaddlePaddle/recordio"
23
	"github.com/coreos/etcd/clientv3"
H
Helin Wang 已提交
24
	log "github.com/sirupsen/logrus"
25 26 27 28
)

// Client is the client of the master server.
type Client struct {
29 30 31
	conn    *connection.Conn
	ch      chan record
	bufSize int
G
gongweibao 已提交
32 33 34 35 36
}

// record is one item passed through Client.ch: either the raw bytes of
// a record (err is nil) or an error reported by getRecords (r is nil).
type record struct {
	r   []byte
	err error
}

39
// WithBuffer sets the client to buffer the training record.
40 41 42
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
43 44 45 46 47
func WithBuffer(bufSize int) func(*Client) error {
	return func(c *Client) error {
		if bufSize <= 0 {
			return nil
		}
48
		c.bufSize = bufSize
49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
		return nil
	}
}

// WithAddr sets the client to use fixed master address.
//
// The address is delivered to monitorMaster through a one-element
// channel. Because a fixed address never changes, the channel is closed
// right after the send so that the monitorMaster goroutine returns once
// it has handled the address.
func WithAddr(addr string) func(c *Client) error {
	return func(c *Client) error {
		ch := make(chan string, 1)
		ch <- addr
		// Nothing else will ever be sent. Without the close, the
		// goroutine below would stay blocked in its range loop for
		// the lifetime of the process.
		close(ch)
		go c.monitorMaster(ch)
		return nil
	}
}

// WithEtcd sets the client to use etcd for master discovery.
//
// endpoints are the etcd endpoints to dial and timeout is used both as
// the dial timeout and as the timeout passed to GetKey for the initial
// lookup of DefaultAddrPath. If a master address is already registered
// it is forwarded to monitorMaster immediately; later changes are
// delivered by watchKey over the same channel.
//
// An error is returned when the etcd client cannot be created or the
// initial lookup fails. In the latter case the etcd client is closed
// before returning.
func WithEtcd(endpoints []string, timeout time.Duration) func(*Client) error {
	return func(c *Client) error {
		cli, err := clientv3.New(clientv3.Config{
			Endpoints:   endpoints,
			DialTimeout: timeout,
		})
		if err != nil {
			return err
		}

		a, err := GetKey(cli, DefaultAddrPath, timeout)
		if err != nil {
			// The etcd client has not been handed to any goroutine
			// yet, so release it here instead of leaking it.
			if cerr := cli.Close(); cerr != nil {
				log.Errorln(cerr)
			}
			return err
		}

		ch := make(chan string, 1)
		if a != "" {
			// Master is registered, send to the master address
			// channel.
			ch <- a
		}

		go watchKey(cli, DefaultAddrPath, ch)
		go c.monitorMaster(ch)
		return nil
	}
}

// NewClient creates a new Client.
func NewClient(opts ...func(*Client) error) (*Client, error) {
94 95
	c := &Client{}
	c.conn = connection.New()
96 97 98 99 100 101 102

	for _, opt := range opts {
		err := opt(c)
		if err != nil {
			return nil, err
		}
	}
103 104 105 106
	c.ch = make(chan record, c.bufSize)
	// FIXME: connection is created asyncrosly in monitorMaster go routine,
	//        ensure the connection is ready for use before calling c.addClient.
	time.Sleep(time.Second)
107
	return c, nil
108 109
}

110 111 112 113 114 115
// StartGetRecords must be called at the beginning of each pass.
//
// It starts a goroutine running getRecords for pass passID and returns
// without waiting for it; the records that goroutine produces are
// consumed through NextRecord.
func (c *Client) StartGetRecords(passID int) {
	go c.getRecords(passID)
}

func (c *Client) getRecords(passID int) {
H
Helin Wang 已提交
116
	for {
117
		t, err := c.getTask(passID)
H
Helin Wang 已提交
118
		if err != nil {
119 120 121 122 123 124 125 126 127 128 129 130
			if err.Error() == ErrPassBefore.Error() ||
				err.Error() == ErrNoMoreAvailable.Error() ||
				err.Error() == ErrAllTaskFailed.Error() {
				c.ch <- record{nil, err}
				break
			}
			if err.Error() == ErrPassAfter.Error() {
				// wait util last pass finishes
				time.Sleep(time.Second * 3)
				continue
			}
			log.Errorf("getTask error: %s", err)
H
Helin Wang 已提交
131 132 133
		}

		for _, chunk := range t.Chunks {
134 135 136
			f, e := os.Open(chunk.Path)
			if e != nil {
				log.Errorln(e)
H
Helin Wang 已提交
137 138 139 140 141
				continue
			}

			s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
			for s.Scan() {
G
gongweibao 已提交
142
				c.ch <- record{s.Record(), nil}
H
Helin Wang 已提交
143 144
			}

145
			if s.Err() != nil {
G
gongweibao 已提交
146
				c.ch <- record{nil, s.Err()}
H
Helin Wang 已提交
147
				log.Errorln(err, chunk.Path)
148 149
			}

H
Helin Wang 已提交
150 151
			err = f.Close()
			if err != nil {
H
Helin Wang 已提交
152
				log.Errorln(err)
H
Helin Wang 已提交
153 154
			}
		}
155 156 157 158

		// We treat a task as finished whenever the last data
		// instance of the task is read. This is not exactly
		// correct, but a reasonable approximation.
H
Helin Wang 已提交
159 160 161 162
		err = c.taskFinished(t.Meta.ID)
		if err != nil {
			log.Errorln(err)
		}
H
Helin Wang 已提交
163 164 165
	}
}

166
func (c *Client) monitorMaster(addrCh <-chan string) {
167
	lastMaster := ""
168
	for curMaster := range addrCh {
H
Helin Wang 已提交
169
		// connect to the new address once address changed.
170 171 172 173
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
				if err != nil {
H
Helin Wang 已提交
174
					log.Errorln(err)
175 176 177 178
				}
			} else {
				err := c.conn.Connect(curMaster)
				if err != nil {
H
Helin Wang 已提交
179
					log.Errorln(err)
180 181 182 183 184 185 186 187 188 189 190 191

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curMaster = lastMaster
				}
			}
		}
		lastMaster = curMaster
	}
}

192 193 194 195
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
196
//
197
// After all tasks are done, another call of SetDataset will start another pass.
198
func (c *Client) SetDataset(globPaths []string) error {
199 200
	err := c.conn.Call("Service.SetDataset", globPaths, nil)
	return err
201 202
}

H
Helin Wang 已提交
203
// getTask gets a new task from the master server.
204
func (c *Client) getTask(passID int) (Task, error) {
205
	var t Task
206
	err := c.conn.Call("Service.GetTask", passID, &t)
207 208 209 210
	return t, err
}

// TaskFinished tells the master server a task is finished.
H
Helin Wang 已提交
211
func (c *Client) taskFinished(taskID int) error {
212
	return c.conn.Call("Service.TaskFinished", taskID, nil)
213
}
H
Helin Wang 已提交
214

G
gongweibao 已提交
215
// TaskFailed tell the master server as task is failed.
G
gongweibao 已提交
216
func (c *Client) taskFailed(meta TaskMeta) error {
G
gongweibao 已提交
217
	return c.conn.Call("Service.TaskFailed", meta, nil)
G
gongweibao 已提交
218 219
}

H
Helin Wang 已提交
220 221
// NextRecord returns next record in the dataset.
//
H
Helin Wang 已提交
222
// NextRecord will block until the next record is available. It is
H
Helin Wang 已提交
223
// thread-safe.
G
gongweibao 已提交
224 225 226
func (c *Client) NextRecord() ([]byte, error) {
	r := <-c.ch
	return r.r, r.err
H
Helin Wang 已提交
227
}
228 229 230 231 232 233 234 235

// RequestSaveModel asks the master server whether the caller should
// save the model.
//
// trainerID and blockDur are packed into a SaveModelRequest and sent
// with the "Service.RequestSaveModel" RPC. The boolean filled in by the
// call is returned together with the call's error.
func (c *Client) RequestSaveModel(trainerID string, blockDur time.Duration) (bool, error) {
	req := SaveModelRequest{
		TrainerID: trainerID,
		BlockDur:  blockDur,
	}

	var approved bool
	err := c.conn.Call("Service.RequestSaveModel", req, &approved)
	return approved, err
}