Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
fa5c3f1f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fa5c3f1f
编写于
6月 14, 2017
作者:
H
Helin Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement master client, Go part
上级
91f82aba
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
290 addition
and
75 deletion
+290
-75
go/master/c/client.go
go/master/c/client.go
+81
-0
go/master/client.go
go/master/client.go
+49
-4
go/master/client_internal_test.go
go/master/client_internal_test.go
+120
-0
go/master/client_test.go
go/master/client_test.go
+22
-63
go/master/service.go
go/master/service.go
+18
-8
未找到文件。
go/master/c/client.go
0 → 100644
浏览文件 @
fa5c3f1f
package
main
/*
typedef int paddle_master_client;
*/
import
"C"
import
(
"log"
"sync"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
)
var
mu
sync
.
Mutex
var
handleMap
=
make
(
map
[
C
.
paddle_master_client
]
*
master
.
Client
)
var
curHandle
C
.
paddle_master_client
func
add
(
c
*
master
.
Client
)
C
.
paddle_master_client
{
mu
.
Lock
()
defer
mu
.
Unlock
()
client
:=
curHandle
curHandle
++
handleMap
[
client
]
=
c
return
client
}
func
get
(
client
C
.
paddle_master_client
)
*
master
.
Client
{
mu
.
Lock
()
defer
mu
.
Unlock
()
return
handleMap
[
client
]
}
func
remove
(
client
C
.
paddle_master_client
)
*
master
.
Client
{
mu
.
Lock
()
defer
mu
.
Unlock
()
h
:=
handleMap
[
client
]
delete
(
handleMap
,
client
)
return
h
}
type
addresser
string
func
(
a
addresser
)
Address
()
string
{
return
string
(
a
)
}
//paddle_new_master_client
func
paddle_new_master_client
(
addr
*
C
.
char
,
buf_size
C
.
int
)
C
.
paddle_master_client
{
a
:=
C
.
GoString
(
addr
)
c
:=
master
.
NewClient
(
addresser
(
a
),
int
(
buf_size
))
return
add
(
c
)
}
//export paddle_new_etcd_master_client
func
paddle_new_etcd_master_client
(
etcd_addr
*
C
.
char
)
C
.
paddle_master_client
{
// TODO(helin): fault tolerant master client using etcd.
panic
(
"not implemented."
)
}
//export paddle_set_dataset
func
paddle_set_dataset
(
client
C
.
paddle_master_client
,
path
**
C
.
char
,
size
C
.
int
)
C
.
int
{
c
:=
get
(
client
)
var
paths
[]
string
for
i
:=
0
;
i
<
int
(
size
);
i
++
{
ptr
:=
(
**
C
.
char
)(
unsafe
.
Pointer
(
uintptr
(
unsafe
.
Pointer
(
path
))
+
uintptr
(
size
)))
str
:=
C
.
GoString
(
*
ptr
)
paths
=
append
(
paths
,
str
)
}
err
:=
c
.
SetDataset
(
paths
)
if
err
!=
nil
{
log
.
Println
(
err
)
return
-
1
}
return
0
}
func
main
()
{}
go/master/client.go
浏览文件 @
fa5c3f1f
...
...
@@ -2,9 +2,11 @@ package master
import
(
"log"
"os"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
// Addresser provide the address of the master server.
...
...
@@ -15,16 +17,51 @@ type Addresser interface {
// Client is the client of the master server.
type
Client
struct
{
conn
*
connection
.
Conn
ch
chan
[]
byte
}
// NewClient creates a new Client.
func
NewClient
(
addr
Addresser
)
*
Client
{
//
// bufSize is the record buffer size. NextRecord will read from the
// buffer.
func
NewClient
(
addr
Addresser
,
bufSize
int
)
*
Client
{
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
c
.
ch
=
make
(
chan
[]
byte
,
bufSize
)
go
c
.
monitorMaster
(
addr
)
go
c
.
getRecords
()
return
c
}
func
(
c
*
Client
)
getRecords
()
{
for
{
t
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
log
.
Println
(
err
)
continue
}
for
_
,
chunk
:=
range
t
.
Chunks
{
f
,
err
:=
os
.
Open
(
chunk
.
Path
)
if
err
!=
nil
{
log
.
Println
(
err
)
continue
}
s
:=
recordio
.
NewRangeScanner
(
f
,
&
chunk
.
Index
,
-
1
,
-
1
)
for
s
.
Scan
()
{
c
.
ch
<-
s
.
Record
()
}
err
=
f
.
Close
()
if
err
!=
nil
{
log
.
Println
(
err
)
}
}
c
.
taskFinished
(
t
.
ID
)
}
}
func
(
c
*
Client
)
monitorMaster
(
addr
Addresser
)
{
lastMaster
:=
""
monitor
:=
func
()
{
...
...
@@ -69,14 +106,22 @@ func (c *Client) SetDataset(globPaths []string) error {
return
c
.
conn
.
Call
(
"Service.SetDataset"
,
globPaths
,
nil
)
}
//
G
etTask gets a new task from the master server.
func
(
c
*
Client
)
G
etTask
()
(
Task
,
error
)
{
//
g
etTask gets a new task from the master server.
func
(
c
*
Client
)
g
etTask
()
(
Task
,
error
)
{
var
t
Task
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
0
,
&
t
)
return
t
,
err
}
// TaskFinished tells the master server a task is finished.
func
(
c
*
Client
)
T
askFinished
(
taskID
int
)
error
{
func
(
c
*
Client
)
t
askFinished
(
taskID
int
)
error
{
return
c
.
conn
.
Call
(
"Service.TaskFinished"
,
taskID
,
nil
)
}
// NextRecord returns next record in the dataset.
//
// NextRecord will block until next record is available. It is
// thread-safe.
func
(
c
*
Client
)
NextRecord
()
[]
byte
{
return
<-
c
.
ch
}
go/master/client_internal_test.go
0 → 100644
浏览文件 @
fa5c3f1f
package
master
import
(
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
log
"github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
const
(
totalTask
=
20
chunkPerTask
=
10
)
func
init
()
{
log
.
SetLevel
(
log
.
ErrorLevel
)
}
type
TestAddresser
string
func
(
a
TestAddresser
)
Address
()
string
{
return
string
(
a
)
}
func
TestGetFinishTask
(
t
*
testing
.
T
)
{
const
path
=
"/tmp/master_client_test_0"
l
,
err
:=
net
.
Listen
(
"tcp"
,
":0"
)
if
err
!=
nil
{
panic
(
err
)
}
ss
:=
strings
.
Split
(
l
.
Addr
()
.
String
(),
":"
)
p
,
err
:=
strconv
.
Atoi
(
ss
[
len
(
ss
)
-
1
])
if
err
!=
nil
{
panic
(
err
)
}
go
func
(
l
net
.
Listener
)
{
s
:=
NewService
(
chunkPerTask
,
time
.
Second
,
1
)
server
:=
rpc
.
NewServer
()
err
:=
server
.
Register
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
mux
:=
http
.
NewServeMux
()
mux
.
Handle
(
rpc
.
DefaultRPCPath
,
server
)
err
=
http
.
Serve
(
l
,
mux
)
if
err
!=
nil
{
panic
(
err
)
}
}(
l
)
f
,
err
:=
os
.
Create
(
path
)
if
err
!=
nil
{
panic
(
err
)
}
for
i
:=
0
;
i
<
totalTask
*
chunkPerTask
;
i
++
{
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
.
Write
(
nil
)
// call Close to force RecordIO writing a chunk.
w
.
Close
()
}
f
.
Close
()
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
go
c
.
monitorMaster
(
TestAddresser
(
fmt
.
Sprintf
(
":%d"
,
p
)))
c
.
SetDataset
([]
string
{
path
})
checkOnePass
:=
func
(
i
int
)
{
var
tasks
[]
Task
for
idx
:=
0
;
idx
<
totalTask
;
idx
++
{
task
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
t
.
Fatal
(
err
,
" pass:"
,
i
)
}
tasks
=
append
(
tasks
,
task
)
}
_
,
err
=
c
.
getTask
()
if
err
==
nil
{
t
.
Fatal
(
"Should get error. Pass:"
,
i
)
}
err
=
c
.
taskFinished
(
tasks
[
0
]
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
err
,
"pass:"
,
i
)
}
tasks
=
tasks
[
1
:
]
task
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
tasks
=
append
(
tasks
,
task
)
for
_
,
task
:=
range
tasks
{
err
=
c
.
taskFinished
(
task
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
err
,
" pass:"
,
i
)
}
}
}
for
i
:=
0
;
i
<
10
;
i
++
{
checkOnePass
(
i
)
}
}
go/master/client_test.go
浏览文件 @
fa5c3f1f
...
...
@@ -11,21 +11,15 @@ import (
"testing"
"time"
log
"github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
const
(
totalTask
=
20
chunkPerTask
=
10
)
var
port
int
func
init
()
{
log
.
SetLevel
(
log
.
ErrorLevel
)
func
TestNextRecord
(
t
*
testing
.
T
)
{
const
(
path
=
"/tmp/master_client_TestFull"
total
=
50
)
l
,
err
:=
net
.
Listen
(
"tcp"
,
":0"
)
if
err
!=
nil
{
...
...
@@ -37,10 +31,9 @@ func init() {
if
err
!=
nil
{
panic
(
err
)
}
port
=
p
go
func
(
l
net
.
Listener
)
{
s
:=
master
.
NewService
(
chunkPerTask
,
time
.
Second
,
1
)
s
:=
master
.
NewService
(
10
,
time
.
Second
,
1
)
server
:=
rpc
.
NewServer
()
err
:=
server
.
Register
(
s
)
if
err
!=
nil
{
...
...
@@ -54,67 +47,33 @@ func init() {
panic
(
err
)
}
}(
l
)
}
type
addresser
string
func
(
a
addresser
)
Address
()
string
{
return
string
(
a
)
}
func
TestClientFull
(
t
*
testing
.
T
)
{
const
p
=
"/tmp/master_client_test_0"
f
,
err
:=
os
.
Create
(
p
)
f
,
err
:=
os
.
Create
(
path
)
if
err
!=
nil
{
panic
(
err
)
}
for
i
:=
0
;
i
<
totalTask
*
chunkPerTask
;
i
++
{
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
.
Write
(
nil
)
// call Close to force RecordIO writing a chunk.
w
.
Close
()
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
for
i
:=
0
;
i
<
total
;
i
++
{
w
.
Write
([]
byte
{
byte
(
i
)})
}
w
.
Close
()
f
.
Close
()
c
:=
master
.
NewClient
(
addresser
(
fmt
.
Sprintf
(
":%d"
,
port
))
)
c
.
SetDataset
([]
string
{
p
})
c
:=
master
.
NewClient
(
master
.
TestAddresser
(
fmt
.
Sprintf
(
":%d"
,
p
)),
10
)
c
.
SetDataset
([]
string
{
p
ath
})
checkOnePass
:=
func
(
i
int
)
{
var
tasks
[]
master
.
Task
for
i
:=
0
;
i
<
total
Task
;
i
++
{
task
,
err
:=
c
.
GetTask
()
if
err
!=
nil
{
t
.
Fatal
(
i
,
er
r
)
for
pass
:=
0
;
pass
<
50
;
pass
++
{
received
:=
make
(
map
[
byte
]
bool
)
for
i
:=
0
;
i
<
total
;
i
++
{
r
:=
c
.
NextRecord
()
if
len
(
r
)
!=
1
{
t
.
Fatal
(
"Length should be 1."
,
r
)
}
tasks
=
append
(
tasks
,
task
)
}
_
,
err
=
c
.
GetTask
()
if
err
==
nil
{
t
.
Fatal
(
i
,
"should get error."
)
}
err
=
c
.
TaskFinished
(
tasks
[
0
]
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
tasks
=
tasks
[
1
:
]
task
,
err
:=
c
.
GetTask
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
tasks
=
append
(
tasks
,
task
)
for
_
,
task
:=
range
tasks
{
err
=
c
.
TaskFinished
(
task
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
i
,
err
)
if
received
[
r
[
0
]]
{
t
.
Fatal
(
"Received duplicate."
,
received
,
r
)
}
received
[
r
[
0
]]
=
true
}
}
for
i
:=
0
;
i
<
10
;
i
++
{
checkOnePass
(
i
)
}
}
go/master/service.go
浏览文件 @
fa5c3f1f
...
...
@@ -217,6 +217,16 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
}
}
// must be called with lock held.
func
(
s
*
Service
)
logFields
()
log
.
Fields
{
return
log
.
Fields
{
"todoLen"
:
len
(
s
.
taskQueues
.
Todo
),
"pendingLen"
:
len
(
s
.
taskQueues
.
Pending
),
"doneLen"
:
len
(
s
.
taskQueues
.
Done
),
"failedLen"
:
len
(
s
.
taskQueues
.
Failed
),
}
}
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
select
{
...
...
@@ -230,7 +240,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
if
len
(
s
.
taskQueues
.
Done
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
err
:=
errors
.
New
(
"all task failed"
)
log
.
W
arningln
(
err
)
log
.
W
ithFields
(
s
.
logFields
())
.
Warningln
(
"All tasks failed."
)
return
err
}
...
...
@@ -243,12 +253,12 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// in package. So we need to figure out a way
// for client to check this error correctly.
err
:=
errors
.
New
(
"no more available task"
)
log
.
W
arningln
(
err
)
log
.
W
ithFields
(
s
.
logFields
())
.
Warningln
(
"No more available task."
)
return
err
}
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
s
.
taskQueues
.
Done
=
nil
log
.
Infoln
(
"No more todo task, but trainer is requesting task to do. Move all done task to todo."
)
log
.
WithFields
(
s
.
logFields
())
.
Infoln
(
"No more todo task, but trainer is requesting task to do. Move all done task to todo."
)
}
t
:=
s
.
taskQueues
.
Todo
[
0
]
...
...
@@ -261,7 +271,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
*
task
=
t
.
Task
log
.
Infof
(
"Task #%d dispatched
\n
"
,
task
.
ID
)
log
.
WithFields
(
s
.
logFields
())
.
Infof
(
"Task #%d dispatched.
"
,
task
.
ID
)
time
.
AfterFunc
(
s
.
timeoutDur
,
s
.
checkTimeoutFunc
(
t
.
Task
.
ID
,
t
.
Epoch
))
return
nil
...
...
@@ -276,12 +286,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
log
.
Infof
(
"Task %d finished
\n
"
,
taskID
)
t
,
ok
:=
s
.
taskQueues
.
Pending
[
taskID
]
if
!
ok
{
err
:=
errors
.
New
(
"pending task not found"
)
log
.
W
arningln
(
err
)
log
.
W
ithFields
(
s
.
logFields
())
.
Warningln
(
"Pending task #%d not found."
,
taskID
)
return
err
}
...
...
@@ -290,8 +298,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s
.
taskQueues
.
Done
=
append
(
s
.
taskQueues
.
Done
,
t
)
delete
(
s
.
taskQueues
.
Pending
,
taskID
)
log
.
WithFields
(
s
.
logFields
())
.
Infof
(
"Task #%d finished."
,
taskID
)
if
len
(
s
.
taskQueues
.
Pending
)
==
0
&&
len
(
s
.
taskQueues
.
Todo
)
==
0
{
log
.
Infoln
(
"No more todo and pending task, start a new pass."
)
log
.
WithFields
(
s
.
logFields
())
.
Infoln
(
"No more todo and pending task, start a new pass."
)
s
.
taskQueues
.
Todo
=
append
(
s
.
taskQueues
.
Todo
,
s
.
taskQueues
.
Done
...
)
s
.
taskQueues
.
Done
=
nil
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录