From c10121e13c2309e81d1842c3ceca733b05f25e91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=AD=A6=E6=AF=85?= Date: Thu, 27 Jul 2017 13:40:25 +0800 Subject: [PATCH] [Done] Sync master client between passes and fix recordio split (#2948) * fix recordio split and task passes * update for pre commit * update * update, still need to sync client wait for pass end. * able to sync passes for task dispatching * update to comment * update * fix yapf check * why local pre-commit fails? version is the same * fix race condition * update * fix race condition * this still have duplicate problem in unit test * update * update * update by comment * update --- .pre-commit-config.yaml | 12 ++-- go/master/c/client.go | 17 +++-- go/master/client.go | 70 ++++++++++-------- go/master/client_internal_test.go | 60 ++++++++-------- go/master/client_test.go | 83 +++++++++++++++------- go/master/service.go | 98 ++++++++++++++++---------- go/master/service_internal_test.go | 3 +- go/pserver/client/c/test/test_train.py | 18 +++-- python/paddle/v2/dataset/common.py | 44 ++++-------- python/paddle/v2/master/client.py | 1 - 10 files changed, 235 insertions(+), 171 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index efb4dcb2df..980a97a07c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,9 +22,11 @@ hooks: - id: clang-formater - repo: https://github.com/PaddlePaddle/pre-commit-golang - sha: 16398aeccf263adaf53b2495eed0406347d76281 + sha: 8337620115c25ff8333f1b1a493bd031049bd7c0 hooks: - - id: go-fmt - types: [go] - - id: gometalinter - types: [go] + - id: go-fmt + types: + - go + - id: gometalinter + types: + - go diff --git a/go/master/c/client.go b/go/master/c/client.go index a2b18e4b47..b5759c30b1 100644 --- a/go/master/c/client.go +++ b/go/master/c/client.go @@ -18,7 +18,6 @@ package main #include #include #include - #define PADDLE_MASTER_OK 0 #define PADDLE_MASTER_ERROR -1 @@ -101,6 +100,12 @@ func paddle_release_master_client(client C.paddle_master_client) { remove(client) } +//export paddle_start_get_records +func paddle_start_get_records(client C.paddle_master_client, pass C.int) { + c := get(client) + c.StartGetRecords(int(pass)) +} + //export paddle_set_dataset func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int) C.int { c := get(client) @@ -121,15 +126,19 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int // paddle_next_record gets the nexts training record. // -// returns number of bytes of the records if success, -1 if failed. +// returns number of bytes of the records if success, -1 if failed, -2 if pass end. // //export paddle_next_record func paddle_next_record(client C.paddle_master_client, record **C.uchar) C.int { c := get(client) r, err := c.NextRecord() if err != nil { - // Error - // TODO: return the type of error? + // NOTE: use errors to indicate pass ends + if err.Error() == master.ErrAllTaskFailed.Error() || + err.Error() == master.ErrNoMoreAvailable.Error() || + err.Error() == master.ErrPassBefore.Error() { + return -2 + } *record = (*C.uchar)(nil) return -1 } diff --git a/go/master/client.go b/go/master/client.go index bbf3768d96..62801b9b7f 100644 --- a/go/master/client.go +++ b/go/master/client.go @@ -16,7 +16,6 @@ package master import ( "os" - "sync" "time" "github.com/PaddlePaddle/Paddle/go/connection" @@ -27,9 +26,9 @@ import ( // Client is the client of the master server. type Client struct { - conn *connection.Conn - ch chan record - initChOnce sync.Once + conn *connection.Conn + ch chan record + bufSize int } type record struct { @@ -46,11 +45,7 @@ func WithBuffer(bufSize int) func(*Client) error { if bufSize <= 0 { return nil } - - c.initChOnce.Do(func() { - c.ch = make(chan record, bufSize) - go c.getRecords() - }) + c.bufSize = bufSize return nil } } @@ -104,25 +99,41 @@ func NewClient(opts ...func(*Client) error) (*Client, error) { if err != nil { return nil, err } - } - + c.ch = make(chan record, c.bufSize) + // FIXME: connection is created asyncrosly in monitorMaster go routine, + // ensure the connection is ready for use before calling c.addClient. + time.Sleep(time.Second) return c, nil } -func (c *Client) getRecords() { +// StartGetRecords must be called at beginning of each pass +func (c *Client) StartGetRecords(passID int) { + go c.getRecords(passID) +} + +func (c *Client) getRecords(passID int) { for { - t, err := c.getTask() + t, err := c.getTask(passID) if err != nil { - log.Errorf("Get task failed, sleep 3 seconds and continue, %s", err) - time.Sleep(3 * time.Second) - continue + if err.Error() == ErrPassBefore.Error() || + err.Error() == ErrNoMoreAvailable.Error() || + err.Error() == ErrAllTaskFailed.Error() { + c.ch <- record{nil, err} + break + } + if err.Error() == ErrPassAfter.Error() { + // wait util last pass finishes + time.Sleep(time.Second * 3) + continue + } + log.Errorf("getTask error: %s", err) } for _, chunk := range t.Chunks { - f, err := os.Open(chunk.Path) - if err != nil { - log.Errorln(err) + f, e := os.Open(chunk.Path) + if e != nil { + log.Errorln(e) continue } @@ -178,18 +189,21 @@ func (c *Client) monitorMaster(addrCh <-chan string) { } } -// SetDataset set dataset for the master server to dispatch. +// SetDataset sets dataset to dispatch for the master server. +// +// SetDataset can be call multiple times at one pass. But only the first call +// will be honored. // -// SetDataset can be call multiple times from different nodes. But -// only the first call will be honored. +// After all tasks are done, another call of SetDataset will start another pass. func (c *Client) SetDataset(globPaths []string) error { - return c.conn.Call("Service.SetDataset", globPaths, nil) + err := c.conn.Call("Service.SetDataset", globPaths, nil) + return err } // getTask gets a new task from the master server. -func (c *Client) getTask() (Task, error) { +func (c *Client) getTask(passID int) (Task, error) { var t Task - err := c.conn.Call("Service.GetTask", 0, &t) + err := c.conn.Call("Service.GetTask", passID, &t) return t, err } @@ -208,12 +222,6 @@ func (c *Client) taskFailed(meta TaskMeta) error { // NextRecord will block until the next record is available. It is // thread-safe. func (c *Client) NextRecord() ([]byte, error) { - c.initChOnce.Do(func() { - // initialize with in case WithBuffer is not used. - c.ch = make(chan record, 0) - go c.getRecords() - }) - r := <-c.ch return r.r, r.err } diff --git a/go/master/client_internal_test.go b/go/master/client_internal_test.go index ee305e2c80..d5f3d79464 100644 --- a/go/master/client_internal_test.go +++ b/go/master/client_internal_test.go @@ -54,22 +54,22 @@ func TestGetFinishTask(t *testing.T) { panic(err) } go func(l net.Listener) { - s, err := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) - if err != nil { - panic(err) + s, sErr := NewService(&InMemStore{}, chunkPerTask, time.Second, 1) + if sErr != nil { + panic(sErr) } server := rpc.NewServer() - err = server.Register(s) - if err != nil { - panic(err) + sErr = server.Register(s) + if sErr != nil { + panic(sErr) } mux := http.NewServeMux() mux.Handle(rpc.DefaultRPCPath, server) - err = http.Serve(l, mux) - if err != nil { - panic(err) + sErr = http.Serve(l, mux) + if sErr != nil { + panic(sErr) } }(l) @@ -103,6 +103,7 @@ func TestGetFinishTask(t *testing.T) { ch := make(chan string, 1) ch <- addr go c.monitorMaster(ch) + err = c.SetDataset([]string{path}) if err != nil { panic(err) @@ -111,44 +112,47 @@ func TestGetFinishTask(t *testing.T) { checkOnePass := func(i int) { var tasks []Task for idx := 0; idx < totalTask; idx++ { - task, err := c.getTask() - if err != nil { - t.Fatalf("Error: %v, pass: %d\n", err, i) + task, cErr := c.getTask(i) + if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() { + t.Fatalf("error: %v, pass: %d\n", cErr, i) } tasks = append(tasks, task) } - _, err = c.getTask() - if err == nil { + // getting task before task finishes should return error + _, cErr := c.getTask(i) + if cErr == nil { t.Fatalf("Should get error, pass: %d\n", i) } - err = c.taskFinished(tasks[0].Meta.ID) - if err != nil { - t.Fatalf("Error: %v, pass: %d\n", err, i) + cErr = c.taskFinished(tasks[0].Meta.ID) + if cErr != nil { + t.Fatalf("Error: %v, pass: %d\n", cErr, i) } - - err = c.taskFailed(tasks[0].Meta) - if err != nil { - t.Fatalf("Error: %v, pass: %d\n", err, i) + // call taskFailed once won't put the task to failed queue, just ensure + // the call + cErr = c.taskFailed(tasks[0].Meta) + if cErr != nil { + t.Fatalf("Error: %v, pass: %d\n", cErr, i) } tasks = tasks[1:] - task, err := c.getTask() - if err != nil { - t.Fatal(err) + _, cErr = c.getTask(i) + if cErr != nil && cErr.Error() != ErrNoMoreAvailable.Error() && cErr.Error() != ErrPassAfter.Error() { + t.Fatalf("Should be ErrNoMoreAvailable or ErrPassAfter: %s", cErr) } - tasks = append(tasks, task) for _, task := range tasks { - err = c.taskFinished(task.Meta.ID) - if err != nil { - t.Fatalf("Error: %v, pass: %d\n", err, i) + cErr = c.taskFinished(task.Meta.ID) + if cErr != nil { + t.Fatal(cErr) } } } for i := 0; i < 10; i++ { + // init pass data + c.StartGetRecords(i) checkOnePass(i) } } diff --git a/go/master/client_test.go b/go/master/client_test.go index a3a434ae7e..79b9cc844d 100644 --- a/go/master/client_test.go +++ b/go/master/client_test.go @@ -20,8 +20,10 @@ import ( "net/http" "net/rpc" "os" + "runtime" "strconv" "strings" + "sync" "testing" "time" @@ -29,6 +31,18 @@ import ( "github.com/PaddlePaddle/recordio" ) +// tool function for testing output goroutine ids +func goid() int { + var buf [64]byte + n := runtime.Stack(buf[:], false) + idField := strings.Fields(strings.TrimPrefix(string(buf[:n]), "goroutine "))[0] + id, err := strconv.Atoi(idField) + if err != nil { + panic(fmt.Sprintf("cannot get goroutine id: %v", err)) + } + return id +} + func TestNextRecord(t *testing.T) { const ( path = "/tmp/master_client_TestFull" @@ -45,7 +59,7 @@ func TestNextRecord(t *testing.T) { panic(err) } go func(l net.Listener) { - s, err := master.NewService(&master.InMemStore{}, 10, time.Second, 1) + s, err := master.NewService(&master.InMemStore{}, 1, time.Second*60, 1) if err != nil { panic(err) } @@ -69,7 +83,7 @@ func TestNextRecord(t *testing.T) { panic(err) } - w := recordio.NewWriter(f, -1, -1) + w := recordio.NewWriter(f, 1, -1) for i := 0; i < total; i++ { _, err = w.Write([]byte{byte(i)}) if err != nil { @@ -87,32 +101,49 @@ func TestNextRecord(t *testing.T) { panic(err) } - c, err := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(10)) - if err != nil { - panic(err) - } - - err = c.SetDataset([]string{path}) - if err != nil { - panic(err) - } - - for pass := 0; pass < 50; pass++ { - received := make(map[byte]bool) - for i := 0; i < total; i++ { - r, err := c.NextRecord() - if err != nil { - t.Fatal(pass, i, "Read error:", err) + // start several client to test task fetching + var wg sync.WaitGroup + for i := 0; i < 4; i++ { + wg.Add(1) + // test for multiple concurrent clients + go func() { + defer wg.Done() + // each go-routine needs a single client connection instance + c, e := master.NewClient(master.WithAddr(fmt.Sprintf(":%d", p)), master.WithBuffer(1)) + if e != nil { + t.Fatal(e) } - - if len(r) != 1 { - t.Fatal(pass, i, "Length should be 1.", r) + e = c.SetDataset([]string{path}) + if e != nil { + panic(e) } - - if received[r[0]] { - t.Fatal(pass, i, "Received duplicate.", received, r) + // test for n passes + for pass := 0; pass < 10; pass++ { + c.StartGetRecords(pass) + + received := make(map[byte]bool) + taskid := 0 + for { + r, e := c.NextRecord() + if e != nil { + // ErrorPassAfter will wait, else break for next pass + if e.Error() == master.ErrPassBefore.Error() || + e.Error() == master.ErrNoMoreAvailable.Error() { + break + } + t.Fatal(pass, taskid, "Read error:", e) + } + if len(r) != 1 { + t.Fatal(pass, taskid, "Length should be 1.", r) + } + if received[r[0]] { + t.Fatal(pass, taskid, "Received duplicate.", received, r) + } + taskid++ + received[r[0]] = true + } } - received[r[0]] = true - } + }() } + wg.Wait() } diff --git a/go/master/service.go b/go/master/service.go index d1ec8939e1..1f2112ecfb 100644 --- a/go/master/service.go +++ b/go/master/service.go @@ -19,6 +19,7 @@ import ( "compress/gzip" "encoding/gob" "errors" + "math/rand" "os" "path/filepath" "sync" @@ -33,6 +34,18 @@ const ( dialTimeout = 5 * time.Second ) +// ErrAllTaskFailed occur when tasks are in done or failed state. +var ErrAllTaskFailed = errors.New("all task finished") + +// ErrNoMoreAvailable occur when no task in todo and yet not all done or fail. +var ErrNoMoreAvailable = errors.New("no more available task") + +// ErrPassBefore client side pass number does not match with master counter. +var ErrPassBefore = errors.New("pass number smaller than master") + +// ErrPassAfter client side pass number does not match with master counter. +var ErrPassAfter = errors.New("pass number larger than master") + // Store is the interface for save and load the master state. type Store interface { Save([]byte) error @@ -75,17 +88,26 @@ type Service struct { chunksPerTask int timeoutDur time.Duration failureMax int - ready chan struct{} store Store - mu sync.Mutex - initDone bool - taskQueues taskQueues + ready chan struct{} + initDone bool + + mu sync.Mutex + taskQueues taskQueues + currPass int + jobTasks []taskEntry + savingTrainer string } func partition(chunks []Chunk, chunksPerTask int) []taskEntry { - id := 0 + // generate uniq id across job using nanosecond + randint + counter + // FIXME(typhoonzero): this is a workaround, use uuid + randStart := rand.Int() + counter := 0 + timestamp := time.Now().Nanosecond() + id := timestamp + randStart + counter if chunksPerTask <= 0 { chunksPerTask = 1 } @@ -95,7 +117,8 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry { for i, c := range chunks { if i%chunksPerTask == 0 && len(cur.Task.Chunks) > 0 { cur.Task.Meta.ID = id - id++ + counter++ + id = timestamp + randStart + counter result = append(result, cur) cur.Task.Chunks = nil } @@ -266,19 +289,21 @@ func (s *Service) SetDataset(globPaths []string, _ *int) error { return err } - s.taskQueues.Todo = partition(chunks, s.chunksPerTask) + s.jobTasks = partition(chunks, s.chunksPerTask) + s.taskQueues.Todo = s.jobTasks err = s.snapshot() if err != nil { log.Errorln(err) return err } - close(s.ready) s.initDone = true return nil } +// processFailedTask retry s.failureMax times for failed task. +// return true if all task are done or failed. func (s *Service) processFailedTask(t taskEntry, epoch int) { if t.Task.Meta.Epoch != epoch { // new epoch, task launched after the @@ -302,8 +327,9 @@ func (s *Service) processFailedTask(t taskEntry, epoch int) { return } - log.Warningf("Task %v failed %d times, discard.", t.Task, t.NumFailure) + log.Warningf("Task %v failed %d times, re-dispatch.", t.Task, t.NumFailure) s.taskQueues.Todo = append(s.taskQueues.Todo, t) + return } func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() { @@ -331,37 +357,30 @@ func (s *Service) logFields() log.Fields { } // GetTask gets a new task from the service. -func (s *Service) GetTask(_ int, task *Task) error { +// passID is the client side pass count +func (s *Service) GetTask(passID int, task *Task) error { select { case <-s.ready: } s.mu.Lock() defer s.mu.Unlock() + if passID < s.currPass { + return ErrPassBefore + } + if passID > s.currPass { + // Client may get run to pass after master when one client faster than the + // other + return ErrPassAfter + } if len(s.taskQueues.Todo) == 0 { - if len(s.taskQueues.Done) == 0 { - if len(s.taskQueues.Pending) == 0 { - err := errors.New("all task failed") - log.WithFields(s.logFields()).Warningln("All tasks failed.") - return err - } - - // TODO(helin): client need to retry in this - // error case. Gotcha: RPC client can't - // compare returned error with predefined - // errors like io.EOF, because the error - // instance deserialized from RPC is a - // different instance than the error defined - // in package. So we need to figure out a way - // for client to check this error correctly. - err := errors.New("no more available task") - log.WithFields(s.logFields()).Warningln("No more available task.") - return err + if len(s.taskQueues.Done) == 0 && len(s.taskQueues.Pending) == 0 { + log.WithFields(s.logFields()).Warningln("All tasks failed, may start next pass") + return ErrAllTaskFailed } - s.taskQueues.Todo = s.taskQueues.Done - s.taskQueues.Done = nil - log.WithFields(s.logFields()).Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.") + log.WithFields(s.logFields()).Warningln("No more available task.") + return ErrNoMoreAvailable } t := s.taskQueues.Todo[0] @@ -381,7 +400,7 @@ func (s *Service) GetTask(_ int, task *Task) error { } // TaskFinished tell the service that a task is finished. -func (s *Service) TaskFinished(taskID int, _ *int) error { +func (s *Service) TaskFinished(taskID int, dummy *int) error { select { case <-s.ready: } @@ -401,11 +420,14 @@ func (s *Service) TaskFinished(taskID int, _ *int) error { delete(s.taskQueues.Pending, taskID) log.WithFields(s.logFields()).Infof("Task #%d finished.", taskID) - - if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 { - log.WithFields(s.logFields()).Infoln("No more todo and pending task, start a new pass.") - s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...) - s.taskQueues.Done = nil + if len(s.taskQueues.Todo) == 0 && len(s.taskQueues.Pending) == 0 { + // increase master side pass count if all tasks finished + s.currPass++ + s.taskQueues.Todo = s.jobTasks + s.taskQueues.Done = []taskEntry{} + // TODO(typhoonzero): deal with failed tasks + s.taskQueues.Failed = []taskEntry{} + log.WithFields(s.logFields()).Warningf("all task finished, add new pass data, newpass: %d.", s.currPass) } err := s.snapshot() @@ -416,7 +438,7 @@ func (s *Service) TaskFinished(taskID int, _ *int) error { } // TaskFailed tells the service that a task is failed. -func (s *Service) TaskFailed(meta TaskMeta, _ *int) error { +func (s *Service) TaskFailed(meta TaskMeta, dummy *int) error { select { case <-s.ready: } diff --git a/go/master/service_internal_test.go b/go/master/service_internal_test.go index 69a882fc33..bd1a939a55 100644 --- a/go/master/service_internal_test.go +++ b/go/master/service_internal_test.go @@ -44,7 +44,8 @@ func TestPartionIndex(t *testing.T) { cs := make([]Chunk, 100) ts := partition(cs, 20) for i := range ts { - if ts[i].Task.Meta.ID != i { + // test auto increament ids + if i > 0 && ts[i].Task.Meta.ID != ts[i-1].Task.Meta.ID+1 { t.Error(ts[i], i) } } diff --git a/go/pserver/client/c/test/test_train.py b/go/pserver/client/c/test/test_train.py index 17082cf892..85cb399590 100644 --- a/go/pserver/client/c/test/test_train.py +++ b/go/pserver/client/c/test/test_train.py @@ -6,16 +6,19 @@ import cPickle as pickle etcd_ip = os.getenv("MASTER_IP", "127.0.0.1") etcd_endpoint = "http://" + etcd_ip + ":2379" +print "connecting to master, etcd endpoints: ", etcd_endpoint +master_client = master.client(etcd_endpoint, 5, 64) def cloud_reader(): - print "connecting to master, etcd endpoints: ", etcd_endpoint - master_client = master.client(etcd_endpoint, 5, 64) + global master_client master_client.set_dataset( - ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*-of-*"]) + ["/pfs/dlnel/public/dataset/uci_housing/uci_housing-*"], passes=30) while 1: r, e = master_client.next_record() if not r: + if e != -2: # other errors + print "get record error:", e break yield pickle.loads(r) @@ -27,10 +30,12 @@ def main(): # network config x = paddle.layer.data(name='x', type=paddle.data_type.dense_vector(13)) y_predict = paddle.layer.fc(input=x, - param_attr=paddle.attr.Param(name='w'), + param_attr=paddle.attr.Param( + name='w', learning_rate=1e-3), size=1, act=paddle.activation.Linear(), - bias_attr=paddle.attr.Param(name='b')) + bias_attr=paddle.attr.Param( + name='b', learning_rate=1e-3)) y = paddle.layer.data(name='y', type=paddle.data_type.dense_vector(1)) cost = paddle.layer.mse_cost(input=y_predict, label=y) @@ -40,7 +45,6 @@ def main(): # create optimizer of new remote updater to pserver optimizer = paddle.optimizer.Momentum(momentum=0, learning_rate=1e-3) - print "etcd endoint: ", etcd_endpoint trainer = paddle.trainer.SGD(cost=cost, parameters=parameters, update_equation=optimizer, @@ -51,6 +55,8 @@ def main(): # event_handler to print training and testing info def event_handler(event): if isinstance(event, paddle.event.EndIteration): + # FIXME: for cloud data reader, pass number is managed by master + # should print the server side pass number if event.batch_id % 100 == 0: print "Pass %d, Batch %d, Cost %f" % ( event.pass_id, event.batch_id, event.cost) diff --git a/python/paddle/v2/dataset/common.py b/python/paddle/v2/dataset/common.py index 645f3cc0dc..111496618d 100644 --- a/python/paddle/v2/dataset/common.py +++ b/python/paddle/v2/dataset/common.py @@ -166,55 +166,37 @@ def cluster_files_reader(files_pattern, return reader -def convert(output_path, - reader, - num_shards, - name_prefix, - max_lines_to_shuffle=1000): +def convert(output_path, reader, line_count, name_prefix): import recordio """ Convert data from reader to recordio format files. :param output_path: directory in which output files will be saved. :param reader: a data reader, from which the convert program will read data instances. - :param num_shards: the number of shards that the dataset will be partitioned into. :param name_prefix: the name prefix of generated files. :param max_lines_to_shuffle: the max lines numbers to shuffle before writing. """ - assert num_shards >= 1 - assert max_lines_to_shuffle >= 1 - - def open_writers(): - w = [] - for i in range(0, num_shards): - n = "%s/%s-%05d-of-%05d" % (output_path, name_prefix, i, - num_shards - 1) - w.append(recordio.writer(n)) - - return w - - def close_writers(w): - for i in range(0, num_shards): - w[i].close() + assert line_count >= 1 + indx_f = 0 - def write_data(w, lines): + def write_data(indx_f, lines): random.shuffle(lines) - for i, d in enumerate(lines): + filename = "%s/%s-%05d" % (output_path, name_prefix, indx_f) + writer = recordio.writer(filename) + for l in lines: # FIXME(Yancey1989): # dumps with protocol: pickle.HIGHEST_PROTOCOL - o = pickle.dumps(d) - w[i % num_shards].write(o) + writer.write(cPickle.dumps(l)) + writer.close() - w = open_writers() lines = [] - for i, d in enumerate(reader()): lines.append(d) - if i % max_lines_to_shuffle == 0 and i >= max_lines_to_shuffle: - write_data(w, lines) + if i % line_count == 0 and i >= line_count: + write_data(indx_f, lines) lines = [] + indx_f += 1 continue - write_data(w, lines) - close_writers(w) + write_data(indx_f, lines) diff --git a/python/paddle/v2/master/client.py b/python/paddle/v2/master/client.py index 3ac62d116b..b658a81630 100644 --- a/python/paddle/v2/master/client.py +++ b/python/paddle/v2/master/client.py @@ -49,7 +49,6 @@ class client(object): def set_dataset(self, paths): holder_type = ctypes.c_char_p * len(paths) holder = holder_type() - print paths for idx, path in enumerate(paths): c_ptr = ctypes.c_char_p(path) holder[idx] = c_ptr -- GitLab