client.go 2.7 KB
Newer Older
1 2 3 4
package master

import (
	"log"
H
Helin Wang 已提交
5
	"os"
6 7 8
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
H
Helin Wang 已提交
9
	"github.com/PaddlePaddle/recordio"
10 11 12 13 14 15 16 17 18 19
)

// Addresser provides the address of the master server.
type Addresser interface {
	// Address returns the current address of the master server. The
	// client treats an empty string as "no master available" and closes
	// its connection in that case.
	Address() string
}

// Client is the client of the master server.
type Client struct {
	conn *connection.Conn
H
Helin Wang 已提交
20
	ch   chan []byte
21 22 23
}

// NewClient creates a new Client.
24
func NewClient(addr Addresser) *Client {
25 26
	c := &Client{}
	c.conn = connection.New()
27
	c.ch = make(chan []byte)
28
	go c.monitorMaster(addr)
H
Helin Wang 已提交
29
	go c.getRecords()
30 31 32
	return c
}

H
Helin Wang 已提交
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52
func (c *Client) getRecords() {
	for {
		t, err := c.getTask()
		if err != nil {
			log.Println(err)
			continue
		}

		for _, chunk := range t.Chunks {
			f, err := os.Open(chunk.Path)
			if err != nil {
				log.Println(err)
				continue
			}

			s := recordio.NewRangeScanner(f, &chunk.Index, -1, -1)
			for s.Scan() {
				c.ch <- s.Record()
			}

53 54 55 56
			if s.Err() != nil {
				log.Println(err, chunk.Path)
			}

H
Helin Wang 已提交
57 58 59 60 61
			err = f.Close()
			if err != nil {
				log.Println(err)
			}
		}
62 63 64 65

		// We treat a task as finished whenever the last data
		// instance of the task is read. This is not exactly
		// correct, but a reasonable approximation.
H
Helin Wang 已提交
66 67 68 69
		c.taskFinished(t.ID)
	}
}

70 71 72
func (c *Client) monitorMaster(addr Addresser) {
	lastMaster := ""
	monitor := func() {
H
Helin Wang 已提交
73 74
		// get the lastest address of the master server,
		// connect to the new address once address changed.
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
		curMaster := addr.Address()
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
				if err != nil {
					log.Println(err)
				}
			} else {
				err := c.conn.Connect(curMaster)
				if err != nil {
					log.Println(err)

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curMaster = lastMaster
				}

			}
		}

		lastMaster = curMaster
	}

	monitor()
	ticker := time.NewTicker(10 * time.Second)
	for _ = range ticker.C {
		monitor()
	}
}

106 107 108 109 110 111 112 113
// SetDataset sets the dataset, given as glob patterns, that the master
// server should dispatch.
//
// SetDataset can be called multiple times from different nodes, but
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error {
	err := c.conn.Call("Service.SetDataset", globPaths, nil)
	return err
}

H
Helin Wang 已提交
114 115
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
116
	var t Task
117
	err := c.conn.Call("Service.GetTask", 0, &t)
118 119 120 121
	return t, err
}

// TaskFinished tells the master server a task is finished.
H
Helin Wang 已提交
122
func (c *Client) taskFinished(taskID int) error {
123
	return c.conn.Call("Service.TaskFinished", taskID, nil)
124
}
H
Helin Wang 已提交
125 126 127 128 129 130 131 132

// NextRecord returns the next record in the dataset.
//
// NextRecord blocks until the background reader delivers a record. It
// only receives from a channel, so it is safe to call from multiple
// goroutines concurrently.
func (c *Client) NextRecord() []byte {
	record := <-c.ch
	return record
}