提交 8c9119af 编写于 作者: G gongweibao 提交者: GitHub

add logs and fix a bug (#5074)

add logs and fix a python path bug
上级 c1fd1dc7
...@@ -123,7 +123,8 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int ...@@ -123,7 +123,8 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
} }
err := c.SetDataset(paths) err := c.SetDataset(paths)
if err != nil { if err != nil {
log.Error("error set dataset", log.Ctx{"error": err}) log.Error("error set dataset",
log.Ctx{"error": err, "paths": paths})
return C.PADDLE_MASTER_ERROR return C.PADDLE_MASTER_ERROR
} }
......
...@@ -121,6 +121,7 @@ func (c *Client) StartGetRecords(passID int) { ...@@ -121,6 +121,7 @@ func (c *Client) StartGetRecords(passID int) {
} }
func (c *Client) getRecords(passID int) { func (c *Client) getRecords(passID int) {
i := 0
for { for {
t, err := c.getTask(passID) t, err := c.getTask(passID)
if err != nil { if err != nil {
...@@ -130,13 +131,21 @@ func (c *Client) getRecords(passID int) { ...@@ -130,13 +131,21 @@ func (c *Client) getRecords(passID int) {
c.ch <- record{nil, err} c.ch <- record{nil, err}
break break
} }
if err.Error() == ErrPassAfter.Error() {
if i%60 == 0 {
log.Debug("getTask of passID error.",
log.Ctx{"error": err, "passID": passID})
i = 0
}
// if err.Error() == ErrPassAfter.Error()
// wait util last pass finishes // wait util last pass finishes
// if other error such as network error
// wait to reconnect or task time out
time.Sleep(time.Second * 3) time.Sleep(time.Second * 3)
i += 3
continue continue
} }
log.Error("getTask error.", log.Ctx{"error": err})
}
for _, chunk := range t.Chunks { for _, chunk := range t.Chunks {
f, e := os.Open(chunk.Path) f, e := os.Open(chunk.Path)
......
...@@ -117,6 +117,7 @@ func TestNextRecord(t *testing.T) { ...@@ -117,6 +117,7 @@ func TestNextRecord(t *testing.T) {
if e != nil { if e != nil {
panic(e) panic(e)
} }
// test for n passes // test for n passes
for pass := 0; pass < 10; pass++ { for pass := 0; pass < 10; pass++ {
c.StartGetRecords(pass) c.StartGetRecords(pass)
......
...@@ -61,7 +61,7 @@ def recordio(paths, buf_size=100): ...@@ -61,7 +61,7 @@ def recordio(paths, buf_size=100):
""" """
Creates a data reader from given RecordIO file paths separated by ",", Creates a data reader from given RecordIO file paths separated by ",",
glob pattern is supported. glob pattern is supported.
:path: path of recordio files. :path: path of recordio files, can be a string or a string list.
:returns: data reader of recordio files. :returns: data reader of recordio files.
""" """
...@@ -92,7 +92,7 @@ def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64): ...@@ -92,7 +92,7 @@ def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64):
""" """
Create a data reader that yield a record one by one from Create a data reader that yield a record one by one from
the paths: the paths:
:path: path of recordio files. :paths: path of recordio files, can be a string or a string list.
:etcd_endpoints: the endpoints for etcd cluster :etcd_endpoints: the endpoints for etcd cluster
:returns: data reader of recordio files. :returns: data reader of recordio files.
...@@ -107,7 +107,12 @@ def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64): ...@@ -107,7 +107,12 @@ def cloud_reader(paths, etcd_endpoints, timeout_sec=5, buf_size=64):
import cPickle as pickle import cPickle as pickle
import paddle.v2.master as master import paddle.v2.master as master
c = master.client(etcd_endpoints, timeout_sec, buf_size) c = master.client(etcd_endpoints, timeout_sec, buf_size)
c.set_dataset(paths)
if isinstance(paths, basestring):
path = [paths]
else:
path = paths
c.set_dataset(path)
def reader(): def reader():
global pass_num global pass_num
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册