client.go 2.8 KB
Newer Older
1 2 3
package master

import (
	"os"
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
	"github.com/PaddlePaddle/recordio"
	log "github.com/sirupsen/logrus"
)

// Client is the client of the master server.
type Client struct {
	conn *connection.Conn
H
Helin Wang 已提交
14
	ch   chan []byte
15 16 17
}

// NewClient creates a new Client.
18 19 20
//
// bufSize is the record buffer size. NextRecord will read from this
// buffer.
21
func NewClient(addrCh <-chan string, bufSize int) *Client {
22 23
	c := &Client{}
	c.conn = connection.New()
24
	c.ch = make(chan []byte, bufSize)
25
	go c.monitorMaster(addrCh)
H
Helin Wang 已提交
26
	go c.getRecords()
27 28 29
	return c
}

H
Helin Wang 已提交
30 31 32 33
func (c *Client) getRecords() {
	for {
		t, err := c.getTask()
		if err != nil {
H
Helin Wang 已提交
34 35
			// TODO(helin): wait before move on with next
			// getTask call.
H
Helin Wang 已提交
36
			log.Errorln(err)
H
Helin Wang 已提交
37 38 39 40 41 42
			continue
		}

		for _, chunk := range t.Chunks {
			f, err := os.Open(chunk.Path)
			if err != nil {
H
Helin Wang 已提交
43
				log.Errorln(err)
H
Helin Wang 已提交
44 45 46 47 48 49 50 51
				continue
			}

			s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
			for s.Scan() {
				c.ch <- s.Record()
			}

52
			if s.Err() != nil {
H
Helin Wang 已提交
53
				log.Errorln(err, chunk.Path)
54 55
			}

H
Helin Wang 已提交
56 57
			err = f.Close()
			if err != nil {
H
Helin Wang 已提交
58
				log.Errorln(err)
H
Helin Wang 已提交
59 60
			}
		}
61 62 63 64

		// We treat a task as finished whenever the last data
		// instance of the task is read. This is not exactly
		// correct, but a reasonable approximation.
H
Helin Wang 已提交
65 66 67 68
		c.taskFinished(t.ID)
	}
}

69
func (c *Client) monitorMaster(addrCh <-chan string) {
70
	lastMaster := ""
71
	for curMaster := range addrCh {
H
Helin Wang 已提交
72
		// connect to the new address once address changed.
73 74 75 76
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
				if err != nil {
H
Helin Wang 已提交
77
					log.Errorln(err)
78 79 80 81
				}
			} else {
				err := c.conn.Connect(curMaster)
				if err != nil {
H
Helin Wang 已提交
82
					log.Errorln(err)
83 84 85 86 87 88 89 90 91 92 93 94

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curMaster = lastMaster
				}
			}
		}
		lastMaster = curMaster
	}
}

95 96 97 98 99 100 101 102
// SetDataset set dataset for the master server to dispatch.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error {
	return c.conn.Call("Service.SetDataset", globPaths, nil)
}

H
Helin Wang 已提交
103 104
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
105
	var t Task
106
	err := c.conn.Call("Service.GetTask", 0, &t)
107 108 109 110
	return t, err
}

// TaskFinished tells the master server a task is finished.
H
Helin Wang 已提交
111
func (c *Client) taskFinished(taskID int) error {
112
	return c.conn.Call("Service.TaskFinished", taskID, nil)
113
}
H
Helin Wang 已提交
114

G
gongweibao 已提交
115
// TaskFailed tell the master server as task is failed.
G
gongweibao 已提交
116 117
func (c *Client) taskFailed(taskID TaskID) error {
	return c.conn.Call("Service.TaskFinished", taskID, nil)
G
gongweibao 已提交
118 119
}

H
Helin Wang 已提交
120 121
// NextRecord returns next record in the dataset.
//
H
Helin Wang 已提交
122
// NextRecord will block until the next record is available. It is
H
Helin Wang 已提交
123 124 125 126
// thread-safe.
func (c *Client) NextRecord() []byte {
	return <-c.ch
}