Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0fa40924
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0fa40924
编写于
6月 29, 2017
作者:
G
gongweibao
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs
上级
4874810b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
47 addition
and
16 deletion
+47
-16
go/master/c/client.go
go/master/c/client.go
+16
-2
go/master/client.go
go/master/client.go
+15
-6
go/master/client_test.go
go/master/client_test.go
+14
-4
python/paddle/v2/reader/creator.py
python/paddle/v2/reader/creator.py
+2
-4
未找到文件。
go/master/c/client.go
浏览文件 @
0fa40924
...
...
@@ -13,6 +13,7 @@ typedef int paddle_master_client;
import
"C"
import
(
"io"
"sync"
"unsafe"
...
...
@@ -84,14 +85,27 @@ func paddle_set_dataset(client C.paddle_master_client, path **C.char, size C.int
return
C
.
PADDLE_MASTER_OK
}
// return value:
// 0:ok
// -1:EOF
// -2:error
//export paddle_next_record
func
paddle_next_record
(
client
C
.
paddle_master_client
,
record
**
C
.
uchar
)
C
.
int
{
c
:=
get
(
client
)
r
:=
c
.
NextRecord
()
if
r
==
nil
{
r
,
err
:=
c
.
NextRecord
()
if
err
==
io
.
EOF
{
// EOF
*
record
=
(
*
C
.
uchar
)(
nullPtr
)
return
-
1
}
if
err
!=
nil
{
// Error
// TODO: return the type of error?
*
record
=
(
*
C
.
uchar
)(
nullPtr
)
return
-
2
}
if
len
(
r
)
==
0
{
// Empty record
*
record
=
(
*
C
.
uchar
)(
nullPtr
)
...
...
go/master/client.go
浏览文件 @
0fa40924
package
master
import
(
"io"
"os"
"time"
...
...
@@ -17,7 +18,12 @@ type Addresser interface {
// Client is the client of the master server.
type
Client
struct
{
conn
*
connection
.
Conn
ch
chan
[]
byte
ch
chan
record
}
type
record
struct
{
r
[]
byte
err
error
}
// NewClient creates a new Client.
...
...
@@ -27,7 +33,7 @@ type Client struct {
func
NewClient
(
addr
Addresser
,
bufSize
int
)
*
Client
{
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
c
.
ch
=
make
(
chan
[]
byte
,
bufSize
)
c
.
ch
=
make
(
chan
record
,
bufSize
)
go
c
.
monitorMaster
(
addr
)
go
c
.
getRecords
()
return
c
...
...
@@ -52,18 +58,20 @@ func (c *Client) getRecords() {
s
:=
recordio
.
NewRangeScanner
(
f
,
&
chunk
.
Index
,
-
1
,
-
1
)
for
s
.
Scan
()
{
c
.
ch
<-
s
.
Record
()
c
.
ch
<-
record
{
s
.
Record
(),
nil
}
}
if
s
.
Err
()
!=
nil
{
c
.
ch
<-
record
{
nil
,
s
.
Err
()}
log
.
Errorln
(
err
,
chunk
.
Path
)
}
err
=
f
.
Close
()
c
.
ch
<-
nil
if
err
!=
nil
{
log
.
Errorln
(
err
)
}
c
.
ch
<-
record
{
nil
,
io
.
EOF
}
}
// We treat a task as finished whenever the last data
...
...
@@ -133,6 +141,7 @@ func (c *Client) taskFinished(taskID int) error {
//
// NextRecord will block until the next record is available. It is
// thread-safe.
func
(
c
*
Client
)
NextRecord
()
[]
byte
{
return
<-
c
.
ch
func
(
c
*
Client
)
NextRecord
()
([]
byte
,
error
)
{
r
:=
<-
c
.
ch
return
r
.
r
,
r
.
err
}
go/master/client_test.go
浏览文件 @
0fa40924
...
...
@@ -2,6 +2,7 @@ package master_test
import
(
"fmt"
"io"
"net"
"net/http"
"net/rpc"
...
...
@@ -69,13 +70,22 @@ func TestNextRecord(t *testing.T) {
for
pass
:=
0
;
pass
<
50
;
pass
++
{
received
:=
make
(
map
[
byte
]
bool
)
for
i
:=
0
;
i
<
total
;
i
++
{
r
:=
c
.
NextRecord
()
for
i
:=
0
;
i
<=
total
;
i
++
{
r
,
err
:=
c
.
NextRecord
()
if
err
==
io
.
EOF
{
break
}
if
err
!=
nil
{
t
.
Fatal
(
pass
,
i
,
"Read error:"
,
err
)
}
if
len
(
r
)
!=
1
{
t
.
Fatal
(
"Length should be 1."
,
r
)
t
.
Fatal
(
pass
,
i
,
"Length should be 1."
,
r
)
}
if
received
[
r
[
0
]]
{
t
.
Fatal
(
"Received duplicate."
,
received
,
r
)
t
.
Fatal
(
pass
,
i
,
"Received duplicate."
,
received
,
r
)
}
received
[
r
[
0
]]
=
true
}
...
...
python/paddle/v2/reader/creator.py
浏览文件 @
0fa40924
...
...
@@ -79,7 +79,6 @@ def recordio_local(paths):
return
reader
def
recordio
(
paths
,
addr
=
""
,
buf_size
=
100
):
"""
Creates a data reader that outputs record one one by one
...
...
@@ -90,8 +89,8 @@ def recordio(paths, addr="", buf_size=100):
import
os
import
paddle.v2.master.client
as
cloud
if
len
(
os
.
environ
[
"KUBERNETES_SERVICE_HOST"
])
==
0
:
return
recordio_local
(
path
)
if
"KUBERNETES_SERVICE_HOST"
not
in
os
.
environ
.
keys
()
:
return
recordio_local
(
path
s
)
def
reader
():
c
=
cloud
(
addr
,
buf_size
)
...
...
@@ -106,4 +105,3 @@ def recordio(paths, addr="", buf_size=100):
c
.
close
()
return
reader
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录