Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
54e8263c
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
54e8263c
编写于
6月 09, 2017
作者:
H
Helin Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement master server client, remove unnecessary dummy variable
上级
72a73ab6
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
174 addition
and
103 deletion
+174
-103
go/cmd/master/master.go
go/cmd/master/master.go
+2
-48
go/master/client.go
go/master/client.go
+10
-4
go/master/client_test.go
go/master/client_test.go
+38
-11
go/master/service.go
go/master/service.go
+107
-14
go/pserver/client.go
go/pserver/client.go
+4
-8
go/pserver/service_test.go
go/pserver/service_test.go
+13
-18
未找到文件。
go/cmd/master/master.go
浏览文件 @
54e8263c
package
main
package
main
import
(
import
(
"fmt"
"net"
"net"
"net/http"
"net/http"
"net/rpc"
"net/rpc"
"os"
"path/filepath"
"strconv"
"strconv"
"strings"
"time"
"time"
"github.com/namsral/flag"
"github.com/namsral/flag"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
)
func
main
()
{
func
main
()
{
port
:=
flag
.
Int
(
"port"
,
8080
,
"port of the master server."
)
port
:=
flag
.
Int
(
"port"
,
8080
,
"port of the master server."
)
dataset
:=
flag
.
String
(
"training_dataset"
,
""
,
"dataset: comma separated path to RecordIO paths, supports golb patterns."
)
faultTolerance
:=
flag
.
Bool
(
"fault_tolerance"
,
false
,
"enable fault tolerance (requires etcd)."
)
faultTolerance
:=
flag
.
Bool
(
"fault_tolerance"
,
false
,
"enable fault tolerance (requires etcd)."
)
taskTimeoutDur
:=
flag
.
Duration
(
"task_timout_dur"
,
20
*
time
.
Minute
,
"task timout duration."
)
taskTimeoutDur
:=
flag
.
Duration
(
"task_timout_dur"
,
20
*
time
.
Minute
,
"task timout duration."
)
taskTimeoutMax
:=
flag
.
Int
(
"task_timeout_max"
,
3
,
"max timtout count for each task before it being declared failed task."
)
taskTimeoutMax
:=
flag
.
Int
(
"task_timeout_max"
,
3
,
"max timtout count for each task before it being declared failed task."
)
chunkPerTask
:=
flag
.
Int
(
"chunk_per_task"
,
10
,
"chunk per task."
)
chunkPerTask
:=
flag
.
Int
(
"chunk_per_task"
,
10
,
"chunk per task."
)
flag
.
Parse
()
flag
.
Parse
()
if
*
dataset
==
""
{
panic
(
"no dataset specified."
)
}
if
*
faultTolerance
{
if
*
faultTolerance
{
panic
(
"fault tolernance not implemented."
)
panic
(
"fault tolernance not implemented."
)
}
var
chunks
[]
master
.
Chunk
var
paths
[]
string
ss
:=
strings
.
Split
(
*
dataset
,
","
)
fmt
.
Println
(
ss
)
for
_
,
s
:=
range
ss
{
match
,
err
:=
filepath
.
Glob
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
paths
=
append
(
paths
,
match
...
)
}
if
len
(
paths
)
==
0
{
panic
(
"no valid datset specified."
)
}
for
_
,
path
:=
range
paths
{
f
,
err
:=
os
.
Open
(
path
)
if
err
!=
nil
{
panic
(
err
)
}
index
,
err
:=
recordio
.
LoadIndex
(
f
)
if
err
!=
nil
{
panic
(
err
)
}
f
.
Close
()
count
:=
index
.
NumChunks
()
for
i
:=
0
;
i
<
count
;
i
++
{
chunk
:=
master
.
Chunk
{
Path
:
path
,
Index
:
*
index
.
ChunkIndex
(
i
),
}
chunks
=
append
(
chunks
,
chunk
)
}
}
}
s
:=
master
.
NewService
(
chunks
,
*
chunkPerTask
,
*
taskTimeoutDur
,
*
taskTimeoutMax
)
s
:=
master
.
NewService
(
*
chunkPerTask
,
*
taskTimeoutDur
,
*
taskTimeoutMax
)
err
:=
rpc
.
Register
(
s
)
err
:=
rpc
.
Register
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
panic
(
err
)
...
...
go/master/client.go
浏览文件 @
54e8263c
...
@@ -59,16 +59,22 @@ func (c *Client) monitorMaster(addr Addresser) {
...
@@ -59,16 +59,22 @@ func (c *Client) monitorMaster(addr Addresser) {
}
}
}
}
// SetDataset set dataset for the master server to dispatch.
//
// SetDataset can be call multiple times from different nodes. But
// only the first call will be honored.
func
(
c
*
Client
)
SetDataset
(
globPaths
[]
string
)
error
{
return
c
.
conn
.
Call
(
"Service.SetDataset"
,
globPaths
,
nil
)
}
// GetTask gets a new task from the master server.
// GetTask gets a new task from the master server.
func
(
c
*
Client
)
GetTask
()
(
Task
,
error
)
{
func
(
c
*
Client
)
GetTask
()
(
Task
,
error
)
{
var
dummy
int
var
t
Task
var
t
Task
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
dummy
,
&
t
)
err
:=
c
.
conn
.
Call
(
"Service.GetTask"
,
0
,
&
t
)
return
t
,
err
return
t
,
err
}
}
// TaskFinished tells the master server a task is finished.
// TaskFinished tells the master server a task is finished.
func
(
c
*
Client
)
TaskFinished
(
taskID
int
)
error
{
func
(
c
*
Client
)
TaskFinished
(
taskID
int
)
error
{
var
dummy
int
return
c
.
conn
.
Call
(
"Service.TaskFinished"
,
taskID
,
nil
)
return
c
.
conn
.
Call
(
"Service.TaskFinished"
,
taskID
,
&
dummy
)
}
}
go/master/client_test.go
浏览文件 @
54e8263c
...
@@ -5,12 +5,14 @@ import (
...
@@ -5,12 +5,14 @@ import (
"net"
"net"
"net/http"
"net/http"
"net/rpc"
"net/rpc"
"os"
"strconv"
"strconv"
"strings"
"strings"
"testing"
"testing"
"time"
"time"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/Paddle/go/master"
"github.com/PaddlePaddle/recordio"
)
)
const
(
const
(
...
@@ -34,8 +36,7 @@ func init() {
...
@@ -34,8 +36,7 @@ func init() {
port
=
p
port
=
p
go
func
(
l
net
.
Listener
)
{
go
func
(
l
net
.
Listener
)
{
chunks
:=
make
([]
master
.
Chunk
,
totalTask
)
s
:=
master
.
NewService
(
chunkPerTask
,
time
.
Second
,
1
)
s
:=
master
.
NewService
(
chunks
,
chunkPerTask
,
time
.
Second
,
1
)
server
:=
rpc
.
NewServer
()
server
:=
rpc
.
NewServer
()
err
:=
server
.
Register
(
s
)
err
:=
server
.
Register
(
s
)
if
err
!=
nil
{
if
err
!=
nil
{
...
@@ -58,21 +59,47 @@ func (a addresser) Address() string {
...
@@ -58,21 +59,47 @@ func (a addresser) Address() string {
}
}
func
TestClientFull
(
t
*
testing
.
T
)
{
func
TestClientFull
(
t
*
testing
.
T
)
{
const
p
=
"/tmp/master_client_test_0"
f
,
err
:=
os
.
Create
(
p
)
if
err
!=
nil
{
panic
(
err
)
}
for
i
:=
0
;
i
<
totalTask
*
chunkPerTask
;
i
++
{
w
:=
recordio
.
NewWriter
(
f
,
-
1
,
-
1
)
w
.
Write
(
nil
)
// call Close to force RecordIO writing a chunk.
w
.
Close
()
}
f
.
Close
()
c
:=
master
.
NewClient
(
addresser
(
fmt
.
Sprintf
(
":%d"
,
port
)))
c
:=
master
.
NewClient
(
addresser
(
fmt
.
Sprintf
(
":%d"
,
port
)))
c
.
SetDataset
([]
string
{
p
})
for
i
:=
0
;
i
<
5
*
totalTask
/
chunkPerTask
;
i
++
{
checkOnePass
:=
func
(
i
int
)
{
var
tasks
[]
master
.
Task
for
i
:=
0
;
i
<
totalTask
;
i
++
{
task
,
err
:=
c
.
GetTask
()
task
,
err
:=
c
.
GetTask
()
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
t
.
Fatal
(
i
,
err
)
}
tasks
=
append
(
tasks
,
task
)
}
}
if
len
(
task
.
Chunks
)
!=
chunkPerTask
{
_
,
err
=
c
.
GetTask
()
t
.
Fatal
(
"wrong number of chunk per task"
,
len
(
task
.
Chunks
))
if
err
==
nil
{
t
.
Fatal
(
i
,
"should get error."
)
}
}
for
_
,
task
:=
range
tasks
{
err
=
c
.
TaskFinished
(
task
.
ID
)
err
=
c
.
TaskFinished
(
task
.
ID
)
if
err
!=
nil
{
if
err
!=
nil
{
panic
(
err
)
t
.
Fatal
(
i
,
err
)
}
}
}
}
}
for
i
:=
0
;
i
<
10
;
i
++
{
checkOnePass
(
i
)
}
}
}
go/master/service.go
浏览文件 @
54e8263c
...
@@ -3,6 +3,8 @@ package master
...
@@ -3,6 +3,8 @@ package master
import
(
import
(
"errors"
"errors"
"log"
"log"
"os"
"path/filepath"
"sync"
"sync"
"time"
"time"
...
@@ -13,18 +15,15 @@ const (
...
@@ -13,18 +15,15 @@ const (
targetTaskCount
=
300
targetTaskCount
=
300
)
)
// errors
var
(
ErrNoMoreTask
=
errors
.
New
(
"no more task for current pass"
)
ErrPendingTaskNotFound
=
errors
.
New
(
"pending task not found"
)
)
// Service is the master server service.
// Service is the master server service.
type
Service
struct
{
type
Service
struct
{
chunksPerTask
int
timeoutDur
time
.
Duration
timeoutDur
time
.
Duration
timeoutMax
int
timeoutMax
int
ready
chan
struct
{}
mu
sync
.
Mutex
mu
sync
.
Mutex
initBegan
bool
taskQueues
taskQueues
taskQueues
taskQueues
}
}
...
@@ -63,13 +62,14 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
...
@@ -63,13 +62,14 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
}
}
// NewService creates a new service.
// NewService creates a new service.
func
NewService
(
chunks
[]
Chunk
,
chunks
PerTask
int
,
timeoutDur
time
.
Duration
,
timeoutMax
int
)
*
Service
{
func
NewService
(
chunksPerTask
int
,
timeoutDur
time
.
Duration
,
timeoutMax
int
)
*
Service
{
s
:=
&
Service
{}
s
:=
&
Service
{}
s
.
chunksPerTask
=
chunksPerTask
s
.
timeoutDur
=
timeoutDur
s
.
timeoutDur
=
timeoutDur
s
.
timeoutMax
=
timeoutMax
s
.
timeoutMax
=
timeoutMax
s
.
taskQueues
=
taskQueues
{}
s
.
taskQueues
=
taskQueues
{}
s
.
taskQueues
.
Pending
=
make
(
map
[
int
]
taskEntry
)
s
.
taskQueues
.
Pending
=
make
(
map
[
int
]
taskEntry
)
s
.
taskQueues
.
Todo
=
partition
(
chunks
,
chunksPerTask
)
s
.
ready
=
make
(
chan
struct
{}
)
return
s
return
s
}
}
...
@@ -104,13 +104,102 @@ func (s *Service) snapshot() error {
...
@@ -104,13 +104,102 @@ func (s *Service) snapshot() error {
return
nil
return
nil
}
}
// SetDataset sets dataset to dispatch for the master server.
//
// SetDataset can be call multiple times. But only the first call will
// be honored.
func
(
s
*
Service
)
SetDataset
(
globPaths
[]
string
,
dummy
*
int
)
error
{
if
len
(
globPaths
)
==
0
{
return
errors
.
New
(
"no dataset specified"
)
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
initBegan
{
// SetDataset already called. All trainer will call
// SetDataset, but we only handle the first one. Treat
// other calls as successful but do nothing.
return
nil
}
s
.
initBegan
=
true
var
chunks
[]
Chunk
var
paths
[]
string
for
_
,
s
:=
range
globPaths
{
match
,
err
:=
filepath
.
Glob
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
paths
=
append
(
paths
,
match
...
)
}
if
len
(
paths
)
==
0
{
return
errors
.
New
(
"no valid datset specified"
)
}
for
_
,
path
:=
range
paths
{
f
,
err
:=
os
.
Open
(
path
)
if
err
!=
nil
{
panic
(
err
)
}
index
,
err
:=
recordio
.
LoadIndex
(
f
)
if
err
!=
nil
{
return
err
}
err
=
f
.
Close
()
if
err
!=
nil
{
return
err
}
count
:=
index
.
NumChunks
()
for
i
:=
0
;
i
<
count
;
i
++
{
chunk
:=
Chunk
{
Path
:
path
,
Index
:
*
index
.
ChunkIndex
(
i
),
}
chunks
=
append
(
chunks
,
chunk
)
}
}
s
.
taskQueues
.
Todo
=
partition
(
chunks
,
s
.
chunksPerTask
)
err
:=
s
.
snapshot
()
if
err
!=
nil
{
return
err
}
close
(
s
.
ready
)
return
nil
}
// GetTask gets a new task from the service.
// GetTask gets a new task from the service.
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
func
(
s
*
Service
)
GetTask
(
dummy
int
,
task
*
Task
)
error
{
select
{
case
<-
s
.
ready
:
}
s
.
mu
.
Lock
()
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
defer
s
.
mu
.
Unlock
()
if
len
(
s
.
taskQueues
.
Todo
)
==
0
{
if
len
(
s
.
taskQueues
.
Todo
)
==
0
{
return
ErrNoMoreTask
if
len
(
s
.
taskQueues
.
Done
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
return
errors
.
New
(
"all task failed"
)
}
// TODO(helin): client need to retry in this
// error case. Gotcha: RPC client can't
// compare returned error with predefined
// erros like io.EOF. Because interface don't
// have same dynamic value when in different
// process.
return
errors
.
New
(
"no more available task"
)
}
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
s
.
taskQueues
.
Todo
=
nil
}
}
t
:=
s
.
taskQueues
.
Todo
[
0
]
t
:=
s
.
taskQueues
.
Todo
[
0
]
...
@@ -163,12 +252,16 @@ func (s *Service) GetTask(dummy int, task *Task) error {
...
@@ -163,12 +252,16 @@ func (s *Service) GetTask(dummy int, task *Task) error {
// TaskFinished tell the service that a task is finished.
// TaskFinished tell the service that a task is finished.
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
TaskFinished
(
taskID
int
,
dummy
*
int
)
error
{
select
{
case
<-
s
.
ready
:
}
s
.
mu
.
Lock
()
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
defer
s
.
mu
.
Unlock
()
t
,
ok
:=
s
.
taskQueues
.
Pending
[
taskID
]
t
,
ok
:=
s
.
taskQueues
.
Pending
[
taskID
]
if
!
ok
{
if
!
ok
{
return
ErrPendingTaskNotFound
return
errors
.
New
(
"pending task not found"
)
}
}
// task finished, reset timeout
// task finished, reset timeout
...
@@ -176,8 +269,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
...
@@ -176,8 +269,8 @@ func (s *Service) TaskFinished(taskID int, dummy *int) error {
s
.
taskQueues
.
Done
=
append
(
s
.
taskQueues
.
Done
,
t
)
s
.
taskQueues
.
Done
=
append
(
s
.
taskQueues
.
Done
,
t
)
delete
(
s
.
taskQueues
.
Pending
,
taskID
)
delete
(
s
.
taskQueues
.
Pending
,
taskID
)
if
len
(
s
.
taskQueues
.
Todo
)
==
0
{
if
len
(
s
.
taskQueues
.
Pending
)
==
0
{
s
.
taskQueues
.
Todo
=
s
.
taskQueues
.
Done
s
.
taskQueues
.
Todo
=
append
(
s
.
taskQueues
.
Todo
,
s
.
taskQueues
.
Done
...
)
s
.
taskQueues
.
Done
=
nil
s
.
taskQueues
.
Done
=
nil
}
}
...
...
go/pserver/client.go
浏览文件 @
54e8263c
...
@@ -102,16 +102,14 @@ func (c *Client) BeginInitParams() bool {
...
@@ -102,16 +102,14 @@ func (c *Client) BeginInitParams() bool {
// InitParam initializes the parameter on parameter servers.
// InitParam initializes the parameter on parameter servers.
func
(
c
*
Client
)
InitParam
(
paramWithConfigs
ParameterWithConfig
)
error
{
func
(
c
*
Client
)
InitParam
(
paramWithConfigs
ParameterWithConfig
)
error
{
var
dummy
int
return
c
.
pservers
[
c
.
partition
(
paramWithConfigs
.
Param
.
Name
)]
.
Call
(
"Service.InitParam"
,
paramWithConfigs
,
nil
)
return
c
.
pservers
[
c
.
partition
(
paramWithConfigs
.
Param
.
Name
)]
.
Call
(
"Service.InitParam"
,
paramWithConfigs
,
&
dummy
)
}
}
// FinishInitParams tells parameter servers client has sent all
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
// parameters to parameter servers as initialization.
func
(
c
*
Client
)
FinishInitParams
()
error
{
func
(
c
*
Client
)
FinishInitParams
()
error
{
for
_
,
p
:=
range
c
.
pservers
{
for
_
,
p
:=
range
c
.
pservers
{
var
dummy
int
err
:=
p
.
Call
(
"Service.FinishInitParams"
,
0
,
nil
)
err
:=
p
.
Call
(
"Service.FinishInitParams"
,
dummy
,
&
dummy
)
if
err
!=
nil
{
if
err
!=
nil
{
return
err
return
err
}
}
...
@@ -125,8 +123,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
...
@@ -125,8 +123,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
errCh
:=
make
(
chan
error
,
len
(
grads
))
errCh
:=
make
(
chan
error
,
len
(
grads
))
for
_
,
g
:=
range
grads
{
for
_
,
g
:=
range
grads
{
go
func
(
g
Gradient
)
{
go
func
(
g
Gradient
)
{
var
dummy
int
err
:=
c
.
pservers
[
c
.
partition
(
g
.
Name
)]
.
Call
(
"Service.SendGrad"
,
g
,
nil
)
err
:=
c
.
pservers
[
c
.
partition
(
g
.
Name
)]
.
Call
(
"Service.SendGrad"
,
g
,
&
dummy
)
errCh
<-
err
errCh
<-
err
}(
g
)
}(
g
)
}
}
...
@@ -205,8 +202,7 @@ func (c *Client) Save(path string) error {
...
@@ -205,8 +202,7 @@ func (c *Client) Save(path string) error {
errCh
:=
make
(
chan
error
,
len
(
c
.
pservers
))
errCh
:=
make
(
chan
error
,
len
(
c
.
pservers
))
for
_
,
p
:=
range
c
.
pservers
{
for
_
,
p
:=
range
c
.
pservers
{
var
dummy
int
err
:=
p
.
Call
(
"Service.Save"
,
path
,
nil
)
err
:=
p
.
Call
(
"Service.Save"
,
path
,
&
dummy
)
errCh
<-
err
errCh
<-
err
}
}
...
...
go/pserver/service_test.go
浏览文件 @
54e8263c
...
@@ -15,8 +15,7 @@ func TestFull(t *testing.T) {
...
@@ -15,8 +15,7 @@ func TestFull(t *testing.T) {
p
.
Name
=
"param_a"
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
p
.
ElementType
=
pserver
.
Int32
var
dummy
int
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p
,
nil
},
nil
)
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
&
dummy
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
...
@@ -25,12 +24,12 @@ func TestFull(t *testing.T) {
...
@@ -25,12 +24,12 @@ func TestFull(t *testing.T) {
p1
.
Name
=
"param_b"
p1
.
Name
=
"param_b"
p1
.
Content
=
[]
byte
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}
p1
.
Content
=
[]
byte
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
}
p1
.
ElementType
=
pserver
.
Float32
p1
.
ElementType
=
pserver
.
Float32
err
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p1
,
Config
:
nil
},
&
dummy
)
err
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p1
,
nil
},
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
...
@@ -46,11 +45,11 @@ func TestFull(t *testing.T) {
...
@@ -46,11 +45,11 @@ func TestFull(t *testing.T) {
}
}
g1
,
g2
:=
pserver
.
Gradient
(
p1
),
pserver
.
Gradient
(
p
)
g1
,
g2
:=
pserver
.
Gradient
(
p1
),
pserver
.
Gradient
(
p
)
err
=
s
.
SendGrad
(
g1
,
&
dummy
)
err
=
s
.
SendGrad
(
g1
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
err
=
s
.
SendGrad
(
g2
,
&
dummy
)
err
=
s
.
SendGrad
(
g2
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
...
@@ -74,23 +73,21 @@ func TestFull(t *testing.T) {
...
@@ -74,23 +73,21 @@ func TestFull(t *testing.T) {
func
TestMultipleInit
(
t
*
testing
.
T
)
{
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
s
:=
pserver
.
NewService
()
var
dummy
int
err
:=
s
.
FinishInitParams
(
0
,
nil
)
err
:=
s
.
FinishInitParams
(
0
,
&
dummy
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
if
err
.
Error
()
!=
pserver
.
AlreadyInitialized
{
if
err
!=
pserver
.
Err
AlreadyInitialized
{
t
.
FailNow
()
t
.
FailNow
()
}
}
}
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
s
:=
pserver
.
NewService
()
var
dummy
int
err
:=
s
.
SendGrad
(
pserver
.
Gradient
{},
nil
)
err
:=
s
.
SendGrad
(
pserver
.
Gradient
{},
&
dummy
)
if
err
!=
pserver
.
ErrUninitialized
{
if
err
.
Error
()
!=
pserver
.
Uninitialized
{
t
.
FailNow
()
t
.
FailNow
()
}
}
}
}
...
@@ -112,8 +109,7 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -112,8 +109,7 @@ func TestBlockUntilInitialized(t *testing.T) {
wg
.
Add
(
1
)
wg
.
Add
(
1
)
go
func
()
{
go
func
()
{
var
dummy
int
err
:=
s
.
Save
(
""
,
nil
)
err
:=
s
.
Save
(
""
,
&
dummy
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
...
@@ -134,13 +130,12 @@ func TestBlockUntilInitialized(t *testing.T) {
...
@@ -134,13 +130,12 @@ func TestBlockUntilInitialized(t *testing.T) {
p
.
Name
=
"param_a"
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
p
.
ElementType
=
pserver
.
Int32
var
dummy
int
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p
,
nil
},
nil
)
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
,
Config
:
nil
},
&
dummy
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
err
=
s
.
FinishInitParams
(
0
,
nil
)
if
err
!=
nil
{
if
err
!=
nil
{
t
.
FailNow
()
t
.
FailNow
()
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录