提交 c10121e1 编写于 作者: 武毅 提交者: GitHub

[Done] Sync master client between passes and fix recordio split (#2948)

* fix recordio split and task passes

* update for pre commit

* update

* update, still need to sync client wait for pass end.

* able to sync passes for task dispatching

* update to comment

* update

* fix yapf check

* why local pre-commit fails? version is the same

* fix race condition

* update

* fix race condition

* this still have duplicate problem in unit test

* update

* update

* update by comment

* update
上级 f96e2157
......@@ -22,9 +22,11 @@
hooks:
- id: clang-formater
- repo: https://github.com/PaddlePaddle/pre-commit-golang
sha: 16398aeccf263adaf53b2495eed0406347d76281
sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
hooks:
- id: go-fmt
types: [go]
- id: gometalinter
types: [go]
- id: go-fmt
types:
- go
- id: gometalinter
types:
- go
......@@ -18,7 +18,6 @@ package main
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#define PADDLE_MASTER_OK 0
#define PADDLE_MASTER_ERROR -1
......@@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) {
remove(client)
}
//export paddle_start_get_records
func paddle_start_get_records(client C.paddle_master_client, pass C.int) {
c := get(client)
c.StartGetRecords(int(pass))
}
//export paddle_set_dataset
func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int {
c := get(client)
......@@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
// paddle_next_record gets the nexts training record.
//
// returns number of bytes of the records if success, -1 if failed.
// returns number of bytes of the records if success, -1 if failed, -2 if pass end.
//
//export paddle_next_record
func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int {
c := get(client)
r, err := c.NextRecord()
if err != nil {
// Error
// TODO: return the type of error?
// NOTE: use errors to indicate pass ends
if err.Error() == master.ErrAllTaskFailed.Error() ||
err.Error() == master.ErrNoMoreAvailable.Error() ||
err.Error() == master.ErrPassBefore.Error() {
return -2
}
*record = (*C.uchar)(nil)
return -1
}
......
......@@ -16,7 +16,6 @@ package master
import (
"os"
"sync"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
......@@ -27,9 +26,9 @@ import (
// Client is the client of the master server.
type Client struct {
conn *connection.Conn
ch chan record
initChOnce sync.Once
conn *connection.Conn
ch chan record
bufSize int
}
type record struct {
......@@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error {
if bufSize <= 0 {
return nil
}
c.initChOnce.Do(func() {
c.ch = make(chan record, bufSize)
go c.getRecords()
})
c.bufSize = bufSize
return nil
}
}
......@@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) {
if err != nil {
return nil, err
}
}
c.ch = make(chan record, c.bufSize)
// FIXME: connection is created asyncrosly in monitorMaster go routine,
// ensure the connection is ready for use before calling c.addClient.
time.Sleep(time.Second)
return c, nil
}
func (c *Client) getRecords() {
// StartGetRecords must be called at beginning of each pass
func (c *Client) StartGetRecords(passID int) {
go c.getRecords(passID)
}
func (c *Client) getRecords(passID int) {
for {
t, err := c.getTask()
t, err := c.getTask(passID)
if err != nil {
log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err)
time.Sleep(3 * time.Second)
continue
if err.Error() == ErrPassBefore.Error() ||
err.Error() == ErrNoMoreAvailable.Error() ||
err.Error() == ErrAllTaskFailed.Error() {
c.ch <- record{nil, err}
break
}
if err.Error() == ErrPassAfter.Error() {
// wait util last pass finishes
time.Sleep(time.Second * 3)
continue
}
log.Errorf("getTask error: %s", err)
}
for _, chunk := range t.Chunks {
f, err := os.Open(chunk.Path)
if err != nil {
log.Errorln(err)
f, e := os.Open(chunk.Path)
if e != nil {
log.Errorln(e)
continue
}
......@@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) {
}
}
// SetDataset set dataset for the master server to dispatch.
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times at one pass. But only the first call
// will be honored.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
// After all tasks are done, another call of SetDataset will start another pass.
func (c *Client) SetDataset(globPaths []string) error {
return c.conn.Call("Service.SetDataset", globPaths, nil)
err := c.conn.Call("Service.SetDataset", globPaths, nil)
return err
}
// getTask gets a new task from the master server.
func (c *Client) getTask() (Task, error) {
func (c *Client) getTask(passID int) (Task, error) {
var t Task
err := c.conn.Call("Service.GetTask", 0, &t)
err := c.conn.Call("Service.GetTask", passID, &t)
return t, err
}
......@@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error {
// NextRecord will block until the next record is available. It is
// thread-safe.
func (c *Client) NextRecord() ([]byte, error) {
c.initChOnce.Do(func() {
// initialize with in case WithBuffer is not used.
c.ch = make(chan record, 0)
go c.getRecords()
})
r := <-c.ch
return r.r, r.err
}
......
......@@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if err != nil {
panic(err)
s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1)
if sErr != nil {
panic(sErr)
}
server := rpc.NewServer()
err = server.Register(s)
if err != nil {
panic(err)
sErr = server.Register(s)
if sErr != nil {
panic(sErr)
}
mux := http.NewServeMux()
mux.Handle(rpc.DefaultRPCPath, server)
err = http.Serve(l, mux)
if err != nil {
panic(err)
sErr = http.Serve(l, mux)
if sErr != nil {
panic(sErr)
}
}(l)
......@@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) {
ch := make(chan string, 1)
ch <- addr
go c.monitorMaster(ch)
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
......@@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) {
checkOnePass := func(i int) {
var tasks []Task
for idx := 0; idx < totalTask; idx++ {
task, err := c.getTask()
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
task, cErr := c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("error: %v, pass: %d\n", cErr, i)
}
tasks = append(tasks, task)
}
_, err = c.getTask()
if err == nil {
// getting task before task finishes should return error
_, cErr := c.getTask(i)
if cErr == nil {
t.Fatalf("Should get error, pass: %d\n", i)
}
err = c.taskFinished(tasks[0].Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(tasks[0].Meta.ID)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
err = c.taskFailed(tasks[0].Meta)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
// call taskFailed once won't put the task to failed queue, just ensure
// the call
cErr = c.taskFailed(tasks[0].Meta)
if cErr != nil {
t.Fatalf("Error: %v, pass: %d\n", cErr, i)
}
tasks = tasks[1:]
task, err := c.getTask()
if err != nil {
t.Fatal(err)
_, cErr = c.getTask(i)
if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() {
t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr)
}
tasks = append(tasks, task)
for _, task := range tasks {
err = c.taskFinished(task.Meta.ID)
if err != nil {
t.Fatalf("Error: %v, pass: %d\n", err, i)
cErr = c.taskFinished(task.Meta.ID)
if cErr != nil {
t.Fatal(cErr)
}
}
}
for i := 0; i < 10; i++ {
// init pass data
c.StartGetRecords(i)
checkOnePass(i)
}
}
......@@ -20,8 +20,10 @@ import (
"net/http"
"net/rpc"
"os"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
......@@ -29,6 +31,18 @@ import (
"github.com/PaddlePaddle/recordio"
)
// tool function for testing output goroutine ids
func goid() int {
var buf [64]byte
n := runtime.Stack(buf[:], false)
idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0]
id, err := strconv.Atoi(idField)
if err != nil {
panic(fmt.Sprintf("cannot get goroutine id: %v", err))
}
return id
}
func TestNextRecord(t *testing.T) {
const (
path = "/tmp/master_client_TestFull"
......@@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
go func(l net.Listener) {
s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1)
s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1)
if err != nil {
panic(err)
}
......@@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
w := recordio.NewWriter(f, -1, -1)
w := recordio.NewWriter(f, 1, -1)
for i := 0; i < total; i++ {
_, err = w.Write([]byte{byte(i)})
if err != nil {
......@@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) {
panic(err)
}
c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10))
if err != nil {
panic(err)
}
err = c.SetDataset([]string{path})
if err != nil {
panic(err)
}
for pass := 0; pass < 50; pass++ {
received := make(map[byte]bool)
for i := 0; i < total; i++ {
r, err := c.NextRecord()
if err != nil {
t.Fatal(pass, i, "Read error:", err)
// start several client to test task fetching
var wg sync.WaitGroup
for i := 0; i < 4; i++ {
wg.Add(1)
// test for multiple concurrent clients
go func() {
defer wg.Done()
// each go-routine needs a single client connection instance
c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1))
if e != nil {
t.Fatal(e)
}
if len(r) != 1 {
t.Fatal(pass, i, "Length should be 1.", r)
e = c.SetDataset([]string{path})
if e != nil {
panic(e)
}
if received[r[0]] {
t.Fatal(pass, i, "Received duplicate.", received, r)
// test for n passes
for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass)
received := make(map[byte]bool)
taskid := 0
for {
r, e := c.NextRecord()
if e != nil {
// ErrorPassAfter will wait, else break for next pass
if e.Error() == master.ErrPassBefore.Error() ||
e.Error() == master.ErrNoMoreAvailable.Error() {
break
}
t.Fatal(pass, taskid, "Read error:", e)
}
if len(r) != 1 {
t.Fatal(pass, taskid, "Length should be 1.", r)
}
if received[r[0]] {
t.Fatal(pass, taskid, "Received duplicate.", received, r)
}
taskid++
received[r[0]] = true
}
}
received[r[0]] = true
}
}()
}
wg.Wait()
}
......@@ -19,6 +19,7 @@ import (
"compress/gzip"
"encoding/gob"
"errors"
"math/rand"
"os"
"path/filepath"
"sync"
......@@ -33,6 +34,18 @@ const (
dialTimeout = 5 * time.Second
)
// ErrAllTaskFailed occur when tasks are in done or failed state.
var ErrAllTaskFailed = errors.New("all task finished")
// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail.
var ErrNoMoreAvailable = errors.New("no more available task")
// ErrPassBefore client side pass number does not match with master counter.
var ErrPassBefore = errors.New("pass number smaller than master")
// ErrPassAfter client side pass number does not match with master counter.
var ErrPassAfter = errors.New("pass number larger than master")
// Store is the interface for save and load the master state.
type Store interface {
Save([]byte) error
......@@ -75,17 +88,26 @@ type Service struct {
chunksPerTask int
timeoutDur time.Duration
failureMax int
ready chan struct{}
store Store
mu sync.Mutex
initDone bool
taskQueues taskQueues
ready chan struct{}
initDone bool
mu sync.Mutex
taskQueues taskQueues
currPass int
jobTasks []taskEntry
savingTrainer string
}
func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
id := 0
// generate uniq id across job using nanosecond + randint + counter
// FIXME(typhoonzero): this is a workaround, use uuid
randStart := rand.Int()
counter := 0
timestamp := time.Now().Nanosecond()
id := timestamp + randStart + counter
if chunksPerTask <= 0 {
chunksPerTask = 1
}
......@@ -95,7 +117,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
for i, c := range chunks {
if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 {
cur.Task.Meta.ID = id
id++
counter++
id = timestamp + randStart + counter
result = append(result, cur)
cur.Task.Chunks = nil
}
......@@ -266,19 +289,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error {
return err
}
s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
s.jobTasks = partition(chunks, s.chunksPerTask)
s.taskQueues.Todo = s.jobTasks
err = s.snapshot()
if err != nil {
log.Errorln(err)
return err
}
close(s.ready)
s.initDone = true
return nil
}
// processFailedTask retry s.failureMax times for failed task.
// return true if all task are done or failed.
func (s *Service) processFailedTask(t taskEntry, epoch int) {
if t.Task.Meta.Epoch != epoch {
// new epoch, task launched after the
......@@ -302,8 +327,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) {
return
}
log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure)
log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure)
s.taskQueues.Todo = append(s.taskQueues.Todo, t)
return
}
func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
......@@ -331,37 +357,30 @@ func (s *Service) logFields() log.Fields {
}
// GetTask gets a new task from the service.
func (s *Service) GetTask(_ int, task *Task) error {
// passID is the client side pass count
func (s *Service) GetTask(passID int, task *Task) error {
select {
case <-s.ready:
}
s.mu.Lock()
defer s.mu.Unlock()
if passID < s.currPass {
return ErrPassBefore
}
if passID > s.currPass {
// Client may get run to pass after master when one client faster than the
// other
return ErrPassAfter
}
if len(s.taskQueues.Todo) == 0 {
if len(s.taskQueues.Done) == 0 {
if len(s.taskQueues.Pending) == 0 {
err := errors.New("all task failed")
log.WithFields(s.logFields()).Warningln("All tasks failed.")
return err
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// errors like io.EOF, because the error
// instance deserialized from RPC is a
// different instance than the error defined
// in package. So we need to figure out a way
// for client to check this error correctly.
err := errors.New("no more available task")
log.WithFields(s.logFields()).Warningln("No more available task.")
return err
if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 {
log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass")
return ErrAllTaskFailed
}
s.taskQueues.Todo = s.taskQueues.Done
s.taskQueues.Done = nil
log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
log.WithFields(s.logFields()).Warningln("No more available task.")
return ErrNoMoreAvailable
}
t := s.taskQueues.Todo[0]
......@@ -381,7 +400,7 @@ func (s *Service) GetTask(_ int, task *Task) error {
}
// TaskFinished tell the service that a task is finished.
func (s *Service) TaskFinished(taskID int, _ *int) error {
func (s *Service) TaskFinished(taskID int, dummy *int) error {
select {
case <-s.ready:
}
......@@ -401,11 +420,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
delete(s.taskQueues.Pending, taskID)
log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID)
if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.")
s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
s.taskQueues.Done = nil
if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 {
// increase master side pass count if all tasks finished
s.currPass++
s.taskQueues.Todo = s.jobTasks
s.taskQueues.Done = []taskEntry{}
// TODO(typhoonzero): deal with failed tasks
s.taskQueues.Failed = []taskEntry{}
log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass)
}
err := s.snapshot()
......@@ -416,7 +438,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error {
}
// TaskFailed tells the service that a task is failed.
func (s *Service) TaskFailed(meta TaskMeta, _ *int) error {
func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error {
select {
case <-s.ready:
}
......
......@@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) {
cs := make([]Chunk, 100)
ts := partition(cs, 20)
for i := range ts {
if ts[i].Task.Meta.ID != i {
// test auto increament ids
if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 {
t.Error(ts[i], i)
}
}
......
......@@ -6,16 +6,19 @@ import cPickle as pickle
etcd_ip = os.getenv("MASTER_IP", "127.0.0.1")
etcd_endpoint = "http://" + etcd_ip + ":2379"
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
def cloud_reader():
print "connecting to master, etcd endpoints: ", etcd_endpoint
master_client = master.client(etcd_endpoint, 5, 64)
global master_client
master_client.set_dataset(
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"])
["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30)
while 1:
r, e = master_client.next_record()
if not r:
if e != -2: # other errors
print "get record error:", e
break
yield pickle.loads(r)
......@@ -27,10 +30,12 @@ def main():
# network config
x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13))
y_predict = paddle.layer.fc(input=x,
param_attr=paddle.attr.Param(name='w'),
param_attr=paddle.attr.Param(
name='w', learning_rate=1e-3),
size=1,
act=paddle.activation.Linear(),
bias_attr=paddle.attr.Param(name='b'))
bias_attr=paddle.attr.Param(
name='b', learning_rate=1e-3))
y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1))
cost = paddle.layer.mse_cost(input=y_predict, label=y)
......@@ -40,7 +45,6 @@ def main():
# create optimizer of new remote updater to pserver
optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3)
print "etcd endoint: ", etcd_endpoint
trainer = paddle.trainer.SGD(cost=cost,
parameters=parameters,
update_equation=optimizer,
......@@ -51,6 +55,8 @@ def main():
# event_handler to print training and testing info
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
# FIXME: for cloud data reader, pass number is managed by master
# should print the server side pass number
if event.batch_id % 100 == 0:
print "Pass %d, Batch %d, Cost %f" % (
event.pass_id, event.batch_id, event.cost)
......
......@@ -166,55 +166,37 @@ def cluster_files_reader(files_pattern,
return reader
def convert(output_path,
reader,
num_shards,
name_prefix,
max_lines_to_shuffle=1000):
def convert(output_path, reader, line_count, name_prefix):
import recordio
"""
Convert data from reader to recordio format files.
:param output_path: directory in which output files will be saved.
:param reader: a data reader, from which the convert program will read data instances.
:param num_shards: the number of shards that the dataset will be partitioned into.
:param name_prefix: the name prefix of generated files.
:param max_lines_to_shuffle: the max lines numbers to shuffle before writing.
"""
assert num_shards >= 1
assert max_lines_to_shuffle >= 1
def open_writers():
w = []
for i in range(0, num_shards):
n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i,
num_shards - 1)
w.append(recordio.writer(n))
return w
def close_writers(w):
for i in range(0, num_shards):
w[i].close()
assert line_count >= 1
indx_f = 0
def write_data(w, lines):
def write_data(indx_f, lines):
random.shuffle(lines)
for i, d in enumerate(lines):
filename = "%s/%s-%05d" % (output_path, name_prefix, indx_f)
writer = recordio.writer(filename)
for l in lines:
# FIXME(Yancey1989):
# dumps with protocol: pickle.HIGHEST_PROTOCOL
o = pickle.dumps(d)
w[i % num_shards].write(o)
writer.write(cPickle.dumps(l))
writer.close()
w = open_writers()
lines = []
for i, d in enumerate(reader()):
lines.append(d)
if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle:
write_data(w, lines)
if i % line_count == 0 and i >= line_count:
write_data(indx_f, lines)
lines = []
indx_f += 1
continue
write_data(w, lines)
close_writers(w)
write_data(indx_f, lines)
......@@ -49,7 +49,6 @@ class client(object):
def set_dataset(self, paths):
holder_type = ctypes.c_char_p * len(paths)
holder = holder_type()
print paths
for idx, path in enumerate(paths):
c_ptr = ctypes.c_char_p(path)
holder[idx] = c_ptr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册