Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
54e8263c
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
1 年多 前同步成功
通知
696
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
54e8263c
编写于
6月 09, 2017
作者:
H
Helin Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement master server client, remove unnecessary dummy variable
上级
72a73ab6
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
174 addition
and
103 deletion
+174
-103
go/cmd/master/master.go
go/cmd/master/master.go
+2
-48
go/master/client.go
go/master/client.go
+10
-4
go/master/client_test.go
go/master/client_test.go
+38
-11
go/master/service.go
go/master/service.go
+107
-14
go/pserver/client.go
go/pserver/client.go
+4
-8
go/pserver/service_test.go
go/pserver/service_test.go
+13
-18
未找到文件。
go/cmd/master/master.go
浏览文件 @
54e8263c
package
main
import
(
"fmt"
"net"
"net/http"
"net/rpc"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
func
main
()
{
port
:=
flag
.
Int
(
"port"
,
8080
,
"port of the master server."
)
dataset
:=
flag
.
String
(
"training_dataset"
,
""
,
"dataset: comma separated path to RecordIO paths, supports golb patterns."
)
faultTolerance
:=
flag
.
Bool
(
"fault_tolerance"
,
false
,
"enable fault tolerance (requires etcd)."
)
taskTimeoutDur
:=
flag
.
Duration
(
"task_timout_dur"
,
20
*
time
.
Minute
,
"task timout duration."
)
taskTimeoutMax
:=
flag
.
Int
(
"task_timeout_max"
,
3
,
"max timtout count for each task before it being declared failed task."
)
chunkPerTask
:=
flag
.
Int
(
"chunk_per_task"
,
10
,
"chunk per task."
)
flag
.
Parse
()
if
*
dataset
==
""
{
panic
(
"no dataset specified."
)
}
if
*
faultTolerance
{
panic
(
"fault tolernance not implemented."
)
}
var
chunks
[]
master
.
Chunk
var
paths
[]
string
ss
:=
strings
.
Split
(
*
dataset
,
","
)
fmt
.
Println
(
ss
)
for
_
,
s
:=
range
ss
{
match
,
err
:=
filepath
.
Glob
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
paths
=
append
(
paths
,
match
...
)
}
if
len
(
paths
)
==
0
{
panic
(
"no valid datset specified."
)
}
for
_
,
path
:=
range
paths
{
f
,
err
:=
os
.
Open
(
path
)
if
err
!=
nil
{
panic
(
err
)
}
index
,
err
:=
recordio
.
LoadIndex
(
f
)
if
err
!=
nil
{
panic
(
err
)
}
f
.
Close
()
count
:=
index
.
NumChunks
()
for
i
:=
0
;
i
<
count
;
i
++
{
chunk
:=
master
.
Chunk
{
Path
:
path
,
Index
:
*
index
.
ChunkIndex
(
i
),
}
chunks
=
append
(
chunks
,
chunk
)
}
}
s
:=
master
.
NewService
(
chunks
,
*
chunkPerTask
,
*
taskTimeoutDur
,
*
taskTimeoutMax
)
s
:=
master
.
NewService
(
*
chunkPerTask
,
*
taskTimeoutDur
,
*
taskTimeoutMax
)
err
:=
rpc
.
Register
(
s
)
if
err
!=
nil
{
panic
(
err
)
...
...
go/master/client.go
浏览文件 @
54e8263c
...
...
@@ -59,16 +59,22 @@ func (c *Client) monitorMaster(addr Addresser) {
}
}
// SetDataset set dataset for the master server to dispatch.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
func
(
c
*
Client
)
SetDataset
(
globPaths
[]
string
)
error
{
return
c
.
conn
.
Call
(
"Service.SetDataset"
,
globPaths
,
nil
)
}
// GetTask gets a new task from the master server.
func
(
c
*
Client
)
GetTask
()
(
Task
,
error
)
{
var
dummy
int
var
t
Task
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
dummy
,
&
t
)
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
0
,
&
t
)
return
t
,
err
}
// TaskFinished tells the master server a task is finished.
func
(
c
*
Client
)
TaskFinished
(
taskID
int
)
error
{
var
dummy
int
return
c
.
conn
.
Call
(
"Service.TaskFinished"
,
taskID
,
&
dummy
)
return
c
.
conn
.
Call
(
"Service.TaskFinished"
,
taskID
,
nil
)
}
go/master/client_test.go
浏览文件 @
54e8263c
...
...
@@ -5,12 +5,14 @@ import (
"net"
"net/http"
"net/rpc"
"os"
"strconv"
"strings"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
const
(
...
...
@@ -34,8 +36,7 @@ func init() {
port
=
p
go
func
(
l
net
.
Listener
)
{
chunks
:=
make
([]
master
.
Chunk
,
totalTask
)
s
:=
master
.
NewService
(
chunks
,
chunkPerTask
,
time
.
Second
,
1
)
s
:=
master
.
NewService
(
chunkPerTask
,
time
.
Second
,
1
)
server
:=
rpc
.
NewServer
()
err
:=
server
.
Register
(
s
)
if
err
!=
nil
{
...
...
@@ -58,21 +59,47 @@ func (a addresser) Address() string {
}
func
TestClientFull
(
t
*
testing
.
T
)
{
const
p
=
"/tmp/master_client_test_0"
f
,
err
:=
os
.
Create
(
p
)
if
err
!=
nil
{
panic
(
err
)
}
for
i
:=
0
;
i
<
totalTask
*
chunkPerTask
;
i
++
{
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
.
Write
(
nil
)
// call Close to force RecordIO writing a chunk.
w
.
Close
()
}
f
.
Close
()
c
:=
master
.
NewClient
(
addresser
(
fmt
.
Sprintf
(
":%d"
,
port
)))
c
.
SetDataset
([]
string
{
p
})
for
i
:=
0
;
i
<
5
*
totalTask
/
chunkPerTask
;
i
++
{
task
,
err
:=
c
.
GetTask
()
if
err
!=
nil
{
panic
(
err
)
checkOnePass
:=
func
(
i
int
)
{
var
tasks
[]
master
.
Task
for
i
:=
0
;
i
<
totalTask
;
i
++
{
task
,
err
:=
c
.
GetTask
()
if
err
!=
nil
{
t
.
Fatal
(
i
,
err
)
}
tasks
=
append
(
tasks
,
task
)
}
if
len
(
task
.
Chunks
)
!=
chunkPerTask
{
t
.
Fatal
(
"wrong number of chunk per task"
,
len
(
task
.
Chunks
))
_
,
err
=
c
.
GetTask
()
if
err
==
nil
{
t
.
Fatal
(
i
,
"should get error."
)
}
err
=
c
.
TaskFinished
(
task
.
ID
)
if
err
!=
nil
{
panic
(
err
)
for
_
,
task
:=
range
tasks
{
err
=
c
.
TaskFinished
(
task
.
ID
)
if
err
!=
nil
{
t
.
Fatal
(
i
,
err
)
}
}
}
for
i
:=
0
;
i
<
10
;
i
++
{
checkOnePass
(
i
)
}
}
go/master/service.go
浏览文件 @
54e8263c
...
...
@@ -3,6 +3,8 @@ package master
import
(
"errors"
"log"
"os"
"path/filepath"
"sync"
"time"
...
...
@@ -13,18 +15,15 @@ const (
targetTaskCount
=
300
)
// errors
var
(
ErrNoMoreTask
=
errors
.
New
(
"no more task for current pass"
)
ErrPendingTaskNotFound
=
errors
.
New
(
"pending task not found"
)
)
// Service is the master server service.
type
Service
struct
{
timeoutDur
time
.
Duration
timeoutMax
int
chunksPerTask
int
timeoutDur
time
.
Duration
timeoutMax
int
ready
chan
struct
{}
mu
sync
.
Mutex
initBegan
bool
taskQueues
taskQueues
}
...
...
@@ -63,13 +62,14 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
}
// NewService creates a new service.
func
NewService
(
chunks
[]
Chunk
,
chunks
PerTask
int
,
timeoutDur
time
.
Duration
,
timeoutMax
int
)
*
Service
{
func
NewService
(
chunksPerTask
int
,
timeoutDur
time
.
Duration
,
timeoutMax
int
)
*
Service
{
s
:=
&
Service
{}
s
.
chunksPerTask
=
chunksPerTask
s
.
timeoutDur
=
timeoutDur
s
.
timeoutMax
=
timeoutMax
s
.
taskQueues
=
taskQueues
{}
s
.
taskQueues
.
Pending
=
make
(
map
[
int
]
taskEntry
)
s
.
taskQueues
.
Todo
=
partition
(
chunks
,
chunksPerTask
)
s
.
ready
=
make
(
chan
struct
{}
)
return
s
}
...
...
@@ -104,13 +104,102 @@ func (s *Service) snapshot() error {
return
nil
}
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
dummy
*
int
)
error
{
if
len
(
globPaths
)
==
0
{
return
errors
.
New
(
"no dataset specified"
)
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
initBegan
{
// SetDataset already called. All trainer will call
// SetDataset, but we only handle the first one. Treat
// other calls as successful but do nothing.
return
nil
}
s
.
initBegan
=
true
var
chunks
[]
Chunk
var
paths
[]
string
for
_
,
s
:=
range
globPaths
{
match
,
err
:=
filepath
.
Glob
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
paths
=
append
(
paths
,
match
...
)
}
if
len
(
paths
)
==
0
{
return
errors
.
New
(
"no valid datset specified"
)
}
for
_
,
path
:=
range
paths
{
f
,
err
:=
os
.
Open
(
path
)
if
err
!=
nil
{
panic
(
err
)
}
index
,
err
:=
recordio
.
LoadIndex
(
f
)
if
err
!=
nil
{
return
err
}
err
=
f
.
Close
()
if
err
!=
nil
{
return
err
}
count
:=
index
.
NumChunks
()
for
i
:=
0
;
i
<
count
;
i
++
{
chunk
:=
Chunk
{
Path
:
path
,
Index
:
*
index
.
ChunkIndex
(
i
),
}
chunks
=
append
(
chunks
,
chunk
)
}
}
s
.
taskQueues
.
Todo
=
partition
(
chunks
,
s
.
chunksPerTask
)
err
:=
s
.
snapshot
()
if
err
!=
nil
{
return
err
}
close
(
s
.
ready
)
return
nil
}
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
select
{
case
<-
s
.
ready
:
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
len
(
s
.
taskQueues
.
Todo
)
==
0
{
return
ErrNoMoreTask
if
len
(
s
.
taskQueues
.
Done
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
return
errors
.
New
(
"all task failed"
)
}
// TODO(helin): the client needs to retry in this
// error case. Gotcha: an RPC client can't
// compare the returned error with predefined
// errors like io.EOF, because interface values
// don't have the same dynamic value in a
// different process.
return
errors
.
New
(
"no more available task"
)
}
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
s
.
taskQueues
.
Todo
=
nil
}
t
:=
s
.
taskQueues
.
Todo
[
0
]
...
...
@@ -163,12 +252,16 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// TaskFinished tells the service that a task is finished.
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
dummy
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
t
,
ok
:=
s
.
taskQueues
.
Pending
[
taskID
]
if
!
ok
{
return
ErrPendingTaskNotFound
return
errors
.
New
(
"pending task not found"
)
}
// task finished, reset timeout
...
...
@@ -176,8 +269,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s
.
taskQueues
.
Done
=
append
(
s
.
taskQueues
.
Done
,
t
)
delete
(
s
.
taskQueues
.
Pending
,
taskID
)
if
len
(
s
.
taskQueues
.
Todo
)
==
0
{
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
s
.
taskQueues
.
Todo
=
append
(
s
.
taskQueues
.
Todo
,
s
.
taskQueues
.
Done
...
)
s
.
taskQueues
.
Done
=
nil
}
...
...
go/pserver/client.go
浏览文件 @
54e8263c
...
...
@@ -102,16 +102,14 @@ func (c *Client) BeginInitParams() bool {
// InitParam initializes the parameter on parameter servers.
func
(
c
*
Client
)
InitParam
(
paramWithConfigs
ParameterWithConfig
)
error
{
var
dummy
int
return
c
.
pservers
[
c
.
partition
(
paramWithConfigs
.
Param
.
Name
)]
.
Call
(
"Service.InitParam"
,
paramWithConfigs
,
&
dummy
)
return
c
.
pservers
[
c
.
partition
(
paramWithConfigs
.
Param
.
Name
)]
.
Call
(
"Service.InitParam"
,
paramWithConfigs
,
nil
)
}
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
func
(
c
*
Client
)
FinishInitParams
()
error
{
for
_
,
p
:=
range
c
.
pservers
{
var
dummy
int
err
:=
p
.
Call
(
"Service.FinishInitParams"
,
dummy
,
&
dummy
)
err
:=
p
.
Call
(
"Service.FinishInitParams"
,
0
,
nil
)
if
err
!=
nil
{
return
err
}
...
...
@@ -125,8 +123,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
errCh
:=
make
(
chan
error
,
len
(
grads
))
for
_
,
g
:=
range
grads
{
go
func
(
g
Gradient
)
{
var
dummy
int
err
:=
c
.
pservers
[
c
.
partition
(
g
.
Name
)]
.
Call
(
"Service.SendGrad"
,
g
,
&
dummy
)
err
:=
c
.
pservers
[
c
.
partition
(
g
.
Name
)]
.
Call
(
"Service.SendGrad"
,
g
,
nil
)
errCh
<-
err
}(
g
)
}
...
...
@@ -205,8 +202,7 @@ func (c *Client) Save(path string) error {
errCh
:=
make
(
chan
error
,
len
(
c
.
pservers
))
for
_
,
p
:=
range
c
.
pservers
{
var
dummy
int
err
:=
p
.
Call
(
"Service.Save"
,
path
,
&
dummy
)
err
:=
p
.
Call
(
"Service.Save"
,
path
,
nil
)
errCh
<-
err
}
...
...
go/pserver/service_test.go
浏览文件 @
54e8263c
...
...
@@ -15,8 +15,7 @@ func TestFull(t *testing.T) {
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
var
dummy
int
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
&
dummy
)
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p
,
nil
},
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
...
...
@@ -25,12 +24,12 @@ func TestFull(t *testing.T) {
p1
.
Name
=
"param_b"
p1
.
Content
=
[]
byte
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}
p1
.
ElementType
=
pserver
.
Float32
err
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p1
,
Config
:
nil
},
&
dummy
)
err
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p1
,
nil
},
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
...
...
@@ -46,11 +45,11 @@ func TestFull(t *testing.T) {
}
g1
,
g2
:=
pserver
.
Gradient
(
p1
),
pserver
.
Gradient
(
p
)
err
=
s
.
SendGrad
(
g1
,
&
dummy
)
err
=
s
.
SendGrad
(
g1
,
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
err
=
s
.
SendGrad
(
g2
,
&
dummy
)
err
=
s
.
SendGrad
(
g2
,
nil
)
if
err
!=
nil
{
t
.
FailNow
()
...
...
@@ -74,23 +73,21 @@ func TestFull(t *testing.T) {
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
var
dummy
int
err
:=
s
.
FinishInitParams
(
0
,
&
dummy
)
err
:=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
if
err
.
Error
()
!=
pserver
.
AlreadyInitialized
{
err
=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
pserver
.
Err
AlreadyInitialized
{
t
.
FailNow
()
}
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
var
dummy
int
err
:=
s
.
SendGrad
(
pserver
.
Gradient
{},
&
dummy
)
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
err
:=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
if
err
!=
pserver
.
ErrUninitialized
{
t
.
FailNow
()
}
}
...
...
@@ -112,8 +109,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Add
(
1
)
go
func
()
{
var
dummy
int
err
:=
s
.
Save
(
""
,
&
dummy
)
err
:=
s
.
Save
(
""
,
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
...
...
@@ -134,13 +130,12 @@ func TestBlockUntilInitialized(t *testing.T) {
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
var
dummy
int
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
&
dummy
)
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p
,
nil
},
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
nil
{
t
.
FailNow
()
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录