client.go 2.8 KB
Newer Older
1 2 3 4
package master

import (
	"log"
H
Helin Wang 已提交
5
	"os"
6 7 8
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
H
Helin Wang 已提交
9
	"github.com/PaddlePaddle/recordio"
10 11 12 13 14 15 16 17 18 19
)

// Addresser provides the current address of the master server.
//
// Implementations may return a different address over time; an empty
// string is treated by the client as "no master available".
type Addresser interface {
	// Address returns the master server's current address.
	Address() string
}

// Client is the client of the master server.
type Client struct {
	conn *connection.Conn
H
Helin Wang 已提交
20
	ch   chan []byte
21 22 23
}

// NewClient creates a new Client.
24
func NewClient(addr Addresser) *Client {
25 26
	c := &Client{}
	c.conn = connection.New()
27
	c.ch = make(chan []byte)
28
	go c.monitorMaster(addr)
H
Helin Wang 已提交
29
	go c.getRecords()
30 31 32
	return c
}

H
Helin Wang 已提交
33 34 35 36
func (c *Client) getRecords() {
	for {
		t, err := c.getTask()
		if err != nil {
H
Helin Wang 已提交
37 38
			// TODO(helin): wait before move on with next
			// getTask call.
H
Helin Wang 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
			log.Println(err)
			continue
		}

		for _, chunk := range t.Chunks {
			f, err := os.Open(chunk.Path)
			if err != nil {
				log.Println(err)
				continue
			}

			s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
			for s.Scan() {
				c.ch <- s.Record()
			}

55 56 57 58
			if s.Err() != nil {
				log.Println(err, chunk.Path)
			}

H
Helin Wang 已提交
59 60 61 62 63
			err = f.Close()
			if err != nil {
				log.Println(err)
			}
		}
64 65 66 67

		// We treat a task as finished whenever the last data
		// instance of the task is read. This is not exactly
		// correct, but a reasonable approximation.
H
Helin Wang 已提交
68 69 70 71
		c.taskFinished(t.ID)
	}
}

72 73 74
func (c *Client) monitorMaster(addr Addresser) {
	lastMaster := ""
	monitor := func() {
H
Helin Wang 已提交
75 76
		// get the lastest address of the master server,
		// connect to the new address once address changed.
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
		curMaster := addr.Address()
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
				if err != nil {
					log.Println(err)
				}
			} else {
				err := c.conn.Connect(curMaster)
				if err != nil {
					log.Println(err)

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curMaster = lastMaster
				}

			}
		}

		lastMaster = curMaster
	}

	monitor()
	ticker := time.NewTicker(10 * time.Second)
	for _ = range ticker.C {
		monitor()
	}
}

108 109 110 111 112 113 114 115
// SetDataset set dataset for the master server to dispatch.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error {
	return c.conn.Call("Service.SetDataset", globPaths, nil)
}

H
Helin Wang 已提交
116 117
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
118
	var t Task
119
	err := c.conn.Call("Service.GetTask", 0, &t)
120 121 122 123
	return t, err
}

// TaskFinished tells the master server a task is finished.
H
Helin Wang 已提交
124
func (c *Client) taskFinished(taskID int) error {
125
	return c.conn.Call("Service.TaskFinished", taskID, nil)
126
}
H
Helin Wang 已提交
127 128 129

// NextRecord returns next record in the dataset.
//
H
Helin Wang 已提交
130
// NextRecord will block until the next record is available. It is
H
Helin Wang 已提交
131 132 133 134
// thread-safe.
func (c *Client) NextRecord() []byte {
	return <-c.ch
}