Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
fa5c3f1f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fa5c3f1f
编写于
6月 14, 2017
作者:
H
Helin Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement master client, Go part
上级
91f82aba
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
290 addition
and
75 deletion
+290
-75
go/master/c/client.go
go/master/c/client.go
+81
-0
go/master/client.go
go/master/client.go
+49
-4
go/master/client_internal_test.go
go/master/client_internal_test.go
+120
-0
go/master/client_test.go
go/master/client_test.go
+22
-63
go/master/service.go
go/master/service.go
+18
-8
未找到文件。
go/master/c/client.go
0 → 100644
浏览文件 @
fa5c3f1f
package
main
/*
typedef int paddle_master_client;
*/
import
"C"
import
(
"log"
"sync"
"unsafe"
"github.com/PaddlePaddle/Paddle/go/master"
)
var
mu
sync
.
Mutex
var
handleMap
=
make
(
map
[
C
.
paddle_master_client
]
*
master
.
Client
)
var
curHandle
C
.
paddle_master_client
func
add
(
c
*
master
.
Client
)
C
.
paddle_master_client
{
mu
.
Lock
()
defer
mu
.
Unlock
()
client
:=
curHandle
curHandle
++
handleMap
[
client
]
=
c
return
client
}
func
get
(
client
C
.
paddle_master_client
)
*
master
.
Client
{
mu
.
Lock
()
defer
mu
.
Unlock
()
return
handleMap
[
client
]
}
func
remove
(
client
C
.
paddle_master_client
)
*
master
.
Client
{
mu
.
Lock
()
defer
mu
.
Unlock
()
h
:=
handleMap
[
client
]
delete
(
handleMap
,
client
)
return
h
}
type
addresser
string
func
(
a
addresser
)
Address
()
string
{
return
string
(
a
)
}
//paddle_new_master_client
func
paddle_new_master_client
(
addr
*
C
.
char
,
buf_size
C
.
int
)
C
.
paddle_master_client
{
a
:=
C
.
GoString
(
addr
)
c
:=
master
.
NewClient
(
addresser
(
a
),
int
(
buf_size
))
return
add
(
c
)
}
//export paddle_new_etcd_master_client
func
paddle_new_etcd_master_client
(
etcd_addr
*
C
.
char
)
C
.
paddle_master_client
{
// TODO(helin): fault tolerant master client using etcd.
panic
(
"not implemented."
)
}
//export paddle_set_dataset
func
paddle_set_dataset
(
client
C
.
paddle_master_client
,
path
**
C
.
char
,
size
C
.
int
)
C
.
int
{
c
:=
get
(
client
)
var
paths
[]
string
for
i
:=
0
;
i
<
int
(
size
);
i
++
{
ptr
:=
(
**
C
.
char
)(
unsafe
.
Pointer
(
uintptr
(
unsafe
.
Pointer
(
path
))
+
uintptr
(
size
)))
str
:=
C
.
GoString
(
*
ptr
)
paths
=
append
(
paths
,
str
)
}
err
:=
c
.
SetDataset
(
paths
)
if
err
!=
nil
{
log
.
Println
(
err
)
return
-
1
}
return
0
}
func
main
()
{}
go/master/client.go
浏览文件 @
fa5c3f1f
...
@@ -2,9 +2,11 @@ package master
...
@@ -2,9 +2,11 @@ package master
import
(
import
(
"log"
"log"
"os"
"time"
"time"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
)
// Addresser provide the address of the master server.
// Addresser provide the address of the master server.
...
@@ -15,16 +17,51 @@ type Addresser interface {
...
@@ -15,16 +17,51 @@ type Addresser interface {
// Client is the client of the master server.
// Client is the client of the master server.
type
Client
struct
{
type
Client
struct
{
conn
*
connection
.
Conn
conn
*
connection
.
Conn
ch
chan
[]
byte
}
}
// NewClient creates a new Client.
// NewClient creates a new Client.
func
NewClient
(
addr
Addresser
)
*
Client
{
//
// bufSize is the record buffer size. NextRecord will read from the
// buffer.
func
NewClient
(
addr
Addresser
,
bufSize
int
)
*
Client
{
c
:=
&
Client
{}
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
c
.
conn
=
connection
.
New
()
c
.
ch
=
make
(
chan
[]
byte
,
bufSize
)
go
c
.
monitorMaster
(
addr
)
go
c
.
monitorMaster
(
addr
)
go
c
.
getRecords
()
return
c
return
c
}
}
func
(
c
*
Client
)
getRecords
()
{
for
{
t
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
log
.
Println
(
err
)
continue
}
for
_
,
chunk
:=
range
t
.
Chunks
{
f
,
err
:=
os
.
Open
(
chunk
.
Path
)
if
err
!=
nil
{
log
.
Println
(
err
)
continue
}
s
:=
recordio
.
NewRangeScanner
(
f
,
&
chunk
.
Index
,
-
1
,
-
1
)
for
s
.
Scan
()
{
c
.
ch
<-
s
.
Record
()
}
err
=
f
.
Close
()
if
err
!=
nil
{
log
.
Println
(
err
)
}
}
c
.
taskFinished
(
t
.
ID
)
}
}
func
(
c
*
Client
)
monitorMaster
(
addr
Addresser
)
{
func
(
c
*
Client
)
monitorMaster
(
addr
Addresser
)
{
lastMaster
:=
""
lastMaster
:=
""
monitor
:=
func
()
{
monitor
:=
func
()
{
...
@@ -69,14 +106,22 @@ func (c *Client) SetDataset(globPaths []string) error {
...
@@ -69,14 +106,22 @@ func (c *Client) SetDataset(globPaths []string) error {
return
c
.
conn
.
Call
(
"Service.SetDataset"
,
globPaths
,
nil
)
return
c
.
conn
.
Call
(
"Service.SetDataset"
,
globPaths
,
nil
)
}
}
//
G
etTask gets a new task from the master server.
//
g
etTask gets a new task from the master server.
func
(
c
*
Client
)
G
etTask
()
(
Task
,
error
)
{
func
(
c
*
Client
)
g
etTask
()
(
Task
,
error
)
{
var
t
Task
var
t
Task
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
0
,
&
t
)
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
0
,
&
t
)
return
t
,
err
return
t
,
err
}
}
// TaskFinished tells the master server a task is finished.
// TaskFinished tells the master server a task is finished.
func
(
c
*
Client
)
T
askFinished
(
taskID
int
)
error
{
func
(
c
*
Client
)
t
askFinished
(
taskID
int
)
error
{
return
c
.
conn
.
Call
(
"Service.TaskFinished"
,
taskID
,
nil
)
return
c
.
conn
.
Call
(
"Service.TaskFinished"
,
taskID
,
nil
)
}
}
// NextRecord returns next record in the dataset.
//
// NextRecord will block until next record is available. It is
// thread-safe.
func
(
c
*
Client
)
NextRecord
()
[]
byte
{
return
<-
c
.
ch
}
go/master/client_internal_test.go
0 → 100644
浏览文件 @
fa5c3f1f
package
master
import
(
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
log
"github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/connection"
"github.com/PaddlePaddle/recordio"
)
const
(
totalTask
=
20
chunkPerTask
=
10
)
func
init
()
{
log
.
SetLevel
(
log
.
ErrorLevel
)
}
type
TestAddresser
string
func
(
a
TestAddresser
)
Address
()
string
{
return
string
(
a
)
}
func
TestGetFinishTask
(
t
*
testing
.
T
)
{
const
path
=
"/tmp/master_client_test_0"
l
,
err
:=
net
.
Listen
(
"tcp"
,
":0"
)
if
err
!=
nil
{
panic
(
err
)
}
ss
:=
strings
.
Split
(
l
.
Addr
()
.
String
(),
":"
)
p
,
err
:=
strconv
.
Atoi
(
ss
[
len
(
ss
)
-
1
])
if
err
!=
nil
{
panic
(
err
)
}
go
func
(
l
net
.
Listener
)
{
s
:=
NewService
(
chunkPerTask
,
time
.
Second
,
1
)
server
:=
rpc
.
NewServer
()
err
:=
server
.
Register
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
mux
:=
http
.
NewServeMux
()
mux
.
Handle
(
rpc
.
DefaultRPCPath
,
server
)
err
=
http
.
Serve
(
l
,
mux
)
if
err
!=
nil
{
panic
(
err
)
}
}(
l
)
f
,
err
:=
os
.
Create
(
path
)
if
err
!=
nil
{
panic
(
err
)
}
for
i
:=
0
;
i
<
totalTask
*
chunkPerTask
;
i
++
{
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
.
Write
(
nil
)
// call Close to force RecordIO writing a chunk.
w
.
Close
()
}
f
.
Close
()
c
:=
&
Client
{}
c
.
conn
=
connection
.
New
()
go
c
.
monitorMaster
(
TestAddresser
(
fmt
.
Sprintf
(
":%d"
,
p
)))
c
.
SetDataset
([]
string
{
path
})
checkOnePass
:=
func
(
i
int
)
{
var
tasks
[]
Task
for
idx
:=
0
;
idx
<
totalTask
;
idx
++
{
task
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
t
.
Fatal
(
err
,
" pass:"
,
i
)
}
tasks
=
append
(
tasks
,
task
)
}
_
,
err
=
c
.
getTask
()
if
err
==
nil
{
t
.
Fatal
(
"Should get error. Pass:"
,
i
)
}
err
=
c
.
taskFinished
(
tasks
[
0
]
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
err
,
"pass:"
,
i
)
}
tasks
=
tasks
[
1
:
]
task
,
err
:=
c
.
getTask
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
tasks
=
append
(
tasks
,
task
)
for
_
,
task
:=
range
tasks
{
err
=
c
.
taskFinished
(
task
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
err
,
" pass:"
,
i
)
}
}
}
for
i
:=
0
;
i
<
10
;
i
++
{
checkOnePass
(
i
)
}
}
go/master/client_test.go
浏览文件 @
fa5c3f1f
...
@@ -11,21 +11,15 @@ import (
...
@@ -11,21 +11,15 @@ import (
"testing"
"testing"
"time"
"time"
log
"github.com/sirupsen/logrus"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
"github.com/PaddlePaddle/recordio"
)
)
const
(
func
TestNextRecord
(
t
*
testing
.
T
)
{
totalTask
=
20
const
(
chunkPerTask
=
10
path
=
"/tmp/master_client_TestFull"
)
total
=
50
)
var
port
int
func
init
()
{
log
.
SetLevel
(
log
.
ErrorLevel
)
l
,
err
:=
net
.
Listen
(
"tcp"
,
":0"
)
l
,
err
:=
net
.
Listen
(
"tcp"
,
":0"
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -37,10 +31,9 @@ func init() {
...
@@ -37,10 +31,9 @@ func init() {
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
port
=
p
go
func
(
l
net
.
Listener
)
{
go
func
(
l
net
.
Listener
)
{
s
:=
master
.
NewService
(
chunkPerTask
,
time
.
Second
,
1
)
s
:=
master
.
NewService
(
10
,
time
.
Second
,
1
)
server
:=
rpc
.
NewServer
()
server
:=
rpc
.
NewServer
()
err
:=
server
.
Register
(
s
)
err
:=
server
.
Register
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -54,67 +47,33 @@ func init() {
...
@@ -54,67 +47,33 @@ func init() {
panic
(
err
)
panic
(
err
)
}
}
}(
l
)
}(
l
)
}
type
addresser
string
func
(
a
addresser
)
Address
()
string
{
f
,
err
:=
os
.
Create
(
path
)
return
string
(
a
)
}
func
TestClientFull
(
t
*
testing
.
T
)
{
const
p
=
"/tmp/master_client_test_0"
f
,
err
:=
os
.
Create
(
p
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
}
}
for
i
:=
0
;
i
<
totalTask
*
chunkPerTask
;
i
++
{
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
.
Write
(
nil
)
for
i
:=
0
;
i
<
total
;
i
++
{
// call Close to force RecordIO writing a chunk.
w
.
Write
([]
byte
{
byte
(
i
)})
w
.
Close
()
}
}
w
.
Close
()
f
.
Close
()
f
.
Close
()
c
:=
master
.
NewClient
(
addresser
(
fmt
.
Sprintf
(
":%d"
,
port
))
)
c
:=
master
.
NewClient
(
master
.
TestAddresser
(
fmt
.
Sprintf
(
":%d"
,
p
)),
10
)
c
.
SetDataset
([]
string
{
p
})
c
.
SetDataset
([]
string
{
p
ath
})
checkOnePass
:=
func
(
i
int
)
{
for
pass
:=
0
;
pass
<
50
;
pass
++
{
var
tasks
[]
master
.
Task
received
:=
make
(
map
[
byte
]
bool
)
for
i
:=
0
;
i
<
total
Task
;
i
++
{
for
i
:=
0
;
i
<
total
;
i
++
{
task
,
err
:=
c
.
GetTask
()
r
:=
c
.
NextRecord
()
if
err
!=
nil
{
if
len
(
r
)
!=
1
{
t
.
Fatal
(
i
,
er
r
)
t
.
Fatal
(
"Length should be 1."
,
r
)
}
}
tasks
=
append
(
tasks
,
task
)
if
received
[
r
[
0
]]
{
t
.
Fatal
(
"Received duplicate."
,
received
,
r
)
}
}
received
[
r
[
0
]]
=
true
_
,
err
=
c
.
GetTask
()
if
err
==
nil
{
t
.
Fatal
(
i
,
"should get error."
)
}
}
err
=
c
.
TaskFinished
(
tasks
[
0
]
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
tasks
=
tasks
[
1
:
]
task
,
err
:=
c
.
GetTask
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
tasks
=
append
(
tasks
,
task
)
for
_
,
task
:=
range
tasks
{
err
=
c
.
TaskFinished
(
task
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
i
,
err
)
}
}
}
for
i
:=
0
;
i
<
10
;
i
++
{
checkOnePass
(
i
)
}
}
}
}
go/master/service.go
浏览文件 @
fa5c3f1f
...
@@ -217,6 +217,16 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
...
@@ -217,6 +217,16 @@ func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
}
}
}
}
// must be called with lock held.
func
(
s
*
Service
)
logFields
()
log
.
Fields
{
return
log
.
Fields
{
"todoLen"
:
len
(
s
.
taskQueues
.
Todo
),
"pendingLen"
:
len
(
s
.
taskQueues
.
Pending
),
"doneLen"
:
len
(
s
.
taskQueues
.
Done
),
"failedLen"
:
len
(
s
.
taskQueues
.
Failed
),
}
}
// GetTask gets a new task from the service.
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
select
{
select
{
...
@@ -230,7 +240,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
...
@@ -230,7 +240,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
if
len
(
s
.
taskQueues
.
Done
)
==
0
{
if
len
(
s
.
taskQueues
.
Done
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
err
:=
errors
.
New
(
"all task failed"
)
err
:=
errors
.
New
(
"all task failed"
)
log
.
W
arningln
(
err
)
log
.
W
ithFields
(
s
.
logFields
())
.
Warningln
(
"All tasks failed."
)
return
err
return
err
}
}
...
@@ -243,12 +253,12 @@ func (s *Service) GetTask(dummy int, task *Task) error {
...
@@ -243,12 +253,12 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// in package. So we need to figure out a way
// in package. So we need to figure out a way
// for client to check this error correctly.
// for client to check this error correctly.
err
:=
errors
.
New
(
"no more available task"
)
err
:=
errors
.
New
(
"no more available task"
)
log
.
W
arningln
(
err
)
log
.
W
ithFields
(
s
.
logFields
())
.
Warningln
(
"No more available task."
)
return
err
return
err
}
}
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
s
.
taskQueues
.
Done
=
nil
s
.
taskQueues
.
Done
=
nil
log
.
Infoln
(
"No more todo task, but trainer is requesting task to do. Move all done task to todo."
)
log
.
WithFields
(
s
.
logFields
())
.
Infoln
(
"No more todo task, but trainer is requesting task to do. Move all done task to todo."
)
}
}
t
:=
s
.
taskQueues
.
Todo
[
0
]
t
:=
s
.
taskQueues
.
Todo
[
0
]
...
@@ -261,7 +271,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
...
@@ -261,7 +271,7 @@ func (s *Service) GetTask(dummy int, task *Task) error {
}
}
*
task
=
t
.
Task
*
task
=
t
.
Task
log
.
Infof
(
"Task #%d dispatched
\n
"
,
task
.
ID
)
log
.
WithFields
(
s
.
logFields
())
.
Infof
(
"Task #%d dispatched.
"
,
task
.
ID
)
time
.
AfterFunc
(
s
.
timeoutDur
,
s
.
checkTimeoutFunc
(
t
.
Task
.
ID
,
t
.
Epoch
))
time
.
AfterFunc
(
s
.
timeoutDur
,
s
.
checkTimeoutFunc
(
t
.
Task
.
ID
,
t
.
Epoch
))
return
nil
return
nil
...
@@ -276,12 +286,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
...
@@ -276,12 +286,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s
.
mu
.
Lock
()
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
defer
s
.
mu
.
Unlock
()
log
.
Infof
(
"Task %d finished
\n
"
,
taskID
)
t
,
ok
:=
s
.
taskQueues
.
Pending
[
taskID
]
t
,
ok
:=
s
.
taskQueues
.
Pending
[
taskID
]
if
!
ok
{
if
!
ok
{
err
:=
errors
.
New
(
"pending task not found"
)
err
:=
errors
.
New
(
"pending task not found"
)
log
.
W
arningln
(
err
)
log
.
W
ithFields
(
s
.
logFields
())
.
Warningln
(
"Pending task #%d not found."
,
taskID
)
return
err
return
err
}
}
...
@@ -290,8 +298,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
...
@@ -290,8 +298,10 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s
.
taskQueues
.
Done
=
append
(
s
.
taskQueues
.
Done
,
t
)
s
.
taskQueues
.
Done
=
append
(
s
.
taskQueues
.
Done
,
t
)
delete
(
s
.
taskQueues
.
Pending
,
taskID
)
delete
(
s
.
taskQueues
.
Pending
,
taskID
)
log
.
WithFields
(
s
.
logFields
())
.
Infof
(
"Task #%d finished."
,
taskID
)
if
len
(
s
.
taskQueues
.
Pending
)
==
0
&&
len
(
s
.
taskQueues
.
Todo
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
&&
len
(
s
.
taskQueues
.
Todo
)
==
0
{
log
.
Infoln
(
"No more todo and pending task, start a new pass."
)
log
.
WithFields
(
s
.
logFields
())
.
Infoln
(
"No more todo and pending task, start a new pass."
)
s
.
taskQueues
.
Todo
=
append
(
s
.
taskQueues
.
Todo
,
s
.
taskQueues
.
Done
...
)
s
.
taskQueues
.
Todo
=
append
(
s
.
taskQueues
.
Todo
,
s
.
taskQueues
.
Done
...
)
s
.
taskQueues
.
Done
=
nil
s
.
taskQueues
.
Done
=
nil
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录