client.go 2.6 KB
Newer Older
1 2 3 4
package master

import (
	"log"
H
Helin Wang 已提交
5
	"os"
6 7 8
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
H
Helin Wang 已提交
9
	"github.com/PaddlePaddle/recordio"
10 11 12 13 14 15 16 17 18 19
)

// Addresser provides the address of the master server.
//
// Implementations are polled repeatedly by the client, so Address
// should return the most recently known address each time it is
// called.
type Addresser interface {
	// Address returns the current address of the master server. The
	// client treats an empty string as "no master is known".
	Address() string
}

// Client is the client of the master server.
type Client struct {
	conn *connection.Conn
H
Helin Wang 已提交
20
	ch   chan []byte
21 22 23
}

// NewClient creates a new Client.
H
Helin Wang 已提交
24 25 26 27
//
// bufSize is the record buffer size. NextRecord will read from the
// buffer.
func NewClient(addr Addresser, bufSize int) *Client {
28 29
	c := &Client{}
	c.conn = connection.New()
H
Helin Wang 已提交
30
	c.ch = make(chan []byte, bufSize)
31
	go c.monitorMaster(addr)
H
Helin Wang 已提交
32
	go c.getRecords()
33 34 35
	return c
}

H
Helin Wang 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
// getRecords loops forever: it fetches a task from the master server,
// sends every record of every chunk in that task into c.ch, and then
// reports the task as finished. It never returns.
//
// Errors are logged rather than returned, because there is no caller
// to receive them. A chunk whose file cannot be opened is skipped.
func (c *Client) getRecords() {
	// retryDelay is the pause after a failed getTask call. Without it
	// a persistent error would make this loop spin and flood the log.
	const retryDelay = time.Second

	for {
		t, err := c.getTask()
		if err != nil {
			log.Println(err)
			time.Sleep(retryDelay)
			continue
		}

		for _, chunk := range t.Chunks {
			f, err := os.Open(chunk.Path)
			if err != nil {
				log.Println(err)
				continue
			}

			// NOTE(review): the -1, -1 arguments presumably select
			// the whole chunk; confirm against the recordio docs.
			s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
			for s.Scan() {
				// Blocks while the buffer is full, i.e. until
				// NextRecord drains it.
				c.ch <- s.Record()
			}

			if err := f.Close(); err != nil {
				log.Println(err)
			}
		}

		// The task is reported as finished once all of its chunks
		// have been attempted; a failed report is logged, not dropped.
		if err := c.taskFinished(t.ID); err != nil {
			log.Println(err)
		}
	}
}

65 66 67
func (c *Client) monitorMaster(addr Addresser) {
	lastMaster := ""
	monitor := func() {
H
Helin Wang 已提交
68 69
		// get the lastest address of the master server,
		// connect to the new address once address changed.
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
		curMaster := addr.Address()
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
				if err != nil {
					log.Println(err)
				}
			} else {
				err := c.conn.Connect(curMaster)
				if err != nil {
					log.Println(err)

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curMaster = lastMaster
				}

			}
		}

		lastMaster = curMaster
	}

	monitor()
	ticker := time.NewTicker(10 * time.Second)
	for _ = range ticker.C {
		monitor()
	}
}

101 102 103 104 105 106 107 108
// SetDataset sets the dataset that the master server dispatches.
//
// globPaths is passed to the master unchanged as the argument of the
// "Service.SetDataset" call. SetDataset can be called multiple times
// from different nodes, but only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error {
	const method = "Service.SetDataset"

	err := c.conn.Call(method, globPaths, nil)
	return err
}

H
Helin Wang 已提交
109 110
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
111
	var t Task
112
	err := c.conn.Call("Service.GetTask", 0, &t)
113 114 115 116
	return t, err
}

// TaskFinished tells the master server a task is finished.
H
Helin Wang 已提交
117
func (c *Client) taskFinished(taskID int) error {
118
	return c.conn.Call("Service.TaskFinished", taskID, nil)
119
}
H
Helin Wang 已提交
120 121 122 123 124 125 126 127

// NextRecord returns the next record in the dataset.
//
// It receives from the client's record buffer, so it blocks until a
// record is available. It is safe to call from multiple goroutines.
func (c *Client) NextRecord() []byte {
	record := <-c.ch
	return record
}