Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0babf84b
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0babf84b
编写于
7年前
作者:
H
Helin Wang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
implement pserver RPC part, and simple parameter partition.
上级
7e93921a
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
492 addition
and
143 deletion
+492
-143
doc/design/cluster_train/pserver_client.md
doc/design/cluster_train/pserver_client.md
+3
-0
go/pserver/cclient/cclient.go
go/pserver/cclient/cclient.go
+32
-14
go/pserver/client.go
go/pserver/client.go
+189
-11
go/pserver/optimizer.go
go/pserver/optimizer.go
+2
-2
go/pserver/service.go
go/pserver/service.go
+20
-71
go/pserver/service_test.go
go/pserver/service_test.go
+29
-45
paddle/go/pserver/client_test.go
paddle/go/pserver/client_test.go
+123
-0
paddle/go/pserver/internal/connection/conn.go
paddle/go/pserver/internal/connection/conn.go
+84
-0
paddle/go/pserver/partitioner.go
paddle/go/pserver/partitioner.go
+10
-0
未找到文件。
doc/design/cluster_train/pserver_client.md
浏览文件 @
0babf84b
...
...
@@ -136,6 +136,9 @@ int paddle_send_grads(paddle_pserver_client* client, const paddle_gradient* grad
/**
* @brief paddle_get_params gets parameters from parameter servers.
*
* paddle_get_params will block until parameters are initialized on
* the parameter servers.
*
* @param names the array of names of the parameters to get.
* @param dst the destination array of parameters to save to.
* @param len the length of the names array and the paddle_parameter
...
...
This diff is collapsed.
Click to expand it.
go/pserver/cclient/cclient.go
浏览文件 @
0babf84b
...
...
@@ -39,6 +39,7 @@ import "C"
import
(
"log"
"strings"
"sync"
"unsafe"
...
...
@@ -86,29 +87,46 @@ func cArrayToSlice(p unsafe.Pointer, len int) []byte {
return
(
*
[
1
<<
30
]
byte
)(
p
)[
:
len
:
len
]
}
// selector is a constant answer to "should this client initialize the
// parameter servers?". The C caller supplies the answer when it creates
// the client, and it is handed to pserver.NewClient as its selector.
type selector bool

// Select returns the fixed answer stored in s.
func (s selector) Select() bool {
	if s {
		return true
	}
	return false
}
type
lister
[]
pserver
.
Server
func
(
l
lister
)
List
()
[]
pserver
.
Server
{
return
l
}
//export paddle_new_pserver_client
func
paddle_new_pserver_client
(
addr
*
C
.
char
)
C
.
client
{
c
:=
pserver
.
NewClient
(
C
.
GoString
(
addr
))
func
paddle_new_pserver_client
(
addrs
*
C
.
char
,
selected
bool
)
C
.
client
{
a
:=
C
.
GoString
(
addrs
)
as
:=
strings
.
Split
(
a
,
","
)
servers
:=
make
([]
pserver
.
Server
,
len
(
as
))
for
i
:=
range
as
{
servers
[
i
]
.
Index
=
i
servers
[
i
]
.
Addr
=
as
[
i
]
}
c
:=
pserver
.
NewClient
(
lister
(
servers
),
len
(
as
),
selector
(
selected
))
return
add
(
c
)
}
//export paddle_new_etcd_pserver_client
func
paddle_new_etcd_pserver_client
(
etcd_addr
*
C
.
char
)
C
.
client
{
// TODO(helin): fault tolerant pserver client using etcd.
panic
(
"not implemented."
)
}
//export paddle_pserver_client_release
func
paddle_pserver_client_release
(
client
C
.
client
)
{
c
:=
remove
(
client
)
c
.
Cleanup
()
remove
(
client
)
}
//export paddle_begin_init_params
func
paddle_begin_init_params
(
client
C
.
client
,
pserver_config
unsafe
.
Pointer
,
config_len
C
.
int
)
C
.
int
{
func
paddle_begin_init_params
(
client
C
.
client
)
C
.
int
{
c
:=
get
(
client
)
b
:=
cArrayToSlice
(
pserver_config
,
int
(
config_len
))
selected
,
err
:=
c
.
BeginInitParams
(
b
)
if
err
!=
nil
{
log
.
Println
(
err
)
return
-
1
}
if
selected
{
if
selected
:=
c
.
BeginInitParams
();
selected
{
return
1
}
return
0
...
...
@@ -230,7 +248,7 @@ func paddle_get_params(client C.client, names **C.char, dst **C.paddle_parameter
func
paddle_save_model
(
client
C
.
client
,
path
*
C
.
char
)
C
.
int
{
p
:=
C
.
GoString
(
path
)
c
:=
get
(
client
)
err
:=
c
.
Save
Model
(
p
)
err
:=
c
.
Save
(
p
)
if
err
!=
nil
{
log
.
Println
(
err
)
return
-
1
...
...
This diff is collapsed.
Click to expand it.
go/pserver/client.go
浏览文件 @
0babf84b
package
pserver
import
(
"hash/fnv"
"log"
"sort"
"time"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver/internal/connection"
)
// TODO(helin): add RPC call retry logic
// Selector decides whether this client is the one that should
// initialize the parameters on the parameter servers.
type Selector interface {
	// Select reports whether this client has been chosen to perform
	// the initialization.
	Select() bool
}
// Server identifies one parameter server.
type Server struct {
	// Index is the slot this server occupies among the client's
	// parameter server connections.
	Index int
	// Addr is the network address the client connects to.
	Addr string
}
// Lister lists currently available parameter servers.
type
Lister
interface
{
List
()
[]
Server
}
// Client is the client to parameter servers.
type
Client
struct
{
sel
Selector
pservers
[]
*
connection
.
Conn
}
// NewClient creates a new client.
func
NewClient
(
addr
string
)
*
Client
{
return
&
Client
{}
func
NewClient
(
l
Lister
,
pserverNum
int
,
sel
Selector
)
*
Client
{
c
:=
&
Client
{
sel
:
sel
}
c
.
pservers
=
make
([]
*
connection
.
Conn
,
pserverNum
)
for
i
:=
0
;
i
<
pserverNum
;
i
++
{
c
.
pservers
[
i
]
=
connection
.
New
()
}
go
c
.
monitorPservers
(
l
,
pserverNum
)
return
c
}
// monitorPservers monitors pserver addresses, and updates connection
// when the address changes.
func
(
c
*
Client
)
monitorPservers
(
l
Lister
,
pserverNum
int
)
{
knownServers
:=
make
([]
Server
,
pserverNum
)
ticker
:=
time
.
NewTicker
(
10
*
time
.
Second
)
monitor
:=
func
()
{
curServers
:=
make
([]
Server
,
pserverNum
)
list
:=
l
.
List
()
for
_
,
l
:=
range
list
{
curServers
[
l
.
Index
]
=
l
}
for
i
:=
range
knownServers
{
if
knownServers
[
i
]
.
Addr
!=
curServers
[
i
]
.
Addr
{
err
:=
c
.
pservers
[
i
]
.
Connect
(
curServers
[
i
]
.
Addr
)
if
err
!=
nil
{
log
.
Println
(
err
)
// connect to addr failed, set
// to last known addr in order
// to retry next time.
curServers
[
i
]
.
Addr
=
knownServers
[
i
]
.
Addr
}
}
}
knownServers
=
curServers
}
monitor
()
for
_
=
range
ticker
.
C
{
monitor
()
}
}
// BeginInitParams begins to initialize parameters on parameter
...
...
@@ -17,38 +87,146 @@ func NewClient(addr string) *Client {
// servers. Other trainers will be blocked until the initialization is
// done, and they need to get the initialized parameters from
// parameter servers using GetParams.
func
(
c
*
Client
)
BeginInitParams
(
pserverConfigProto
[]
byte
)
(
selected
bool
,
err
error
)
{
return
true
,
nil
func
(
c
*
Client
)
BeginInitParams
(
)
bool
{
return
c
.
sel
.
Select
()
}
// InitParam initializes the parameter on parameter servers.
func
(
c
*
Client
)
InitParam
(
paramWithConfigs
ParameterWithConfig
)
error
{
return
nil
var
dummy
int
return
c
.
pservers
[
c
.
partition
(
paramWithConfigs
.
Param
.
Name
)]
.
Call
(
"Service.InitParam"
,
paramWithConfigs
,
&
dummy
)
}
// FinishInitParams tells parameter servers client has sent all
// parameters to parameter servers as initialization.
func
(
c
*
Client
)
FinishInitParams
()
error
{
for
_
,
p
:=
range
c
.
pservers
{
var
dummy
int
err
:=
p
.
Call
(
"Service.FinishInitParams"
,
dummy
,
&
dummy
)
if
err
!=
nil
{
return
err
}
}
return
nil
}
// SendGrads sends gradients to parameter servers for updating
// parameters.
func
(
c
*
Client
)
SendGrads
(
grads
[]
Gradient
)
error
{
errCh
:=
make
(
chan
error
,
len
(
grads
))
for
_
,
g
:=
range
grads
{
go
func
(
g
Gradient
)
{
var
dummy
int
err
:=
c
.
pservers
[
c
.
partition
(
g
.
Name
)]
.
Call
(
"Service.SendGrad"
,
g
,
&
dummy
)
errCh
<-
err
}(
g
)
}
recv
:=
0
for
err
:=
range
errCh
{
if
err
!=
nil
{
return
err
}
recv
++
if
recv
==
len
(
grads
)
{
break
}
}
return
nil
}
type
result
struct
{
idx
int
p
Parameter
err
error
}
type
results
[]
result
func
(
r
results
)
Len
()
int
{
return
len
(
r
)
}
func
(
r
results
)
Less
(
i
int
,
j
int
)
bool
{
return
r
[
i
]
.
idx
<
r
[
j
]
.
idx
}
func
(
r
results
)
Swap
(
i
int
,
j
int
)
{
r
[
i
],
r
[
j
]
=
r
[
j
],
r
[
i
]
}
// GetParams gets parameters from parameter servers.
func
(
c
*
Client
)
GetParams
(
names
[]
string
)
([]
Parameter
,
error
)
{
return
nil
,
nil
rCh
:=
make
(
chan
result
,
len
(
names
))
for
idx
,
name
:=
range
names
{
go
func
(
name
string
,
idx
int
)
{
var
parameter
Parameter
err
:=
c
.
pservers
[
c
.
partition
(
name
)]
.
Call
(
"Service.GetParam"
,
name
,
&
parameter
)
rCh
<-
result
{
idx
:
idx
,
p
:
parameter
,
err
:
err
}
}(
name
,
idx
)
}
var
rs
results
recv
:=
0
for
r
:=
range
rCh
{
if
r
.
err
!=
nil
{
return
nil
,
r
.
err
}
rs
=
append
(
rs
,
r
)
recv
++
if
recv
==
len
(
names
)
{
break
}
}
sort
.
Sort
(
rs
)
ps
:=
make
([]
Parameter
,
len
(
rs
))
for
i
:=
range
rs
{
ps
[
i
]
=
rs
[
i
]
.
p
}
return
ps
,
nil
}
// SaveModel indicates parameters to save the parameter to the given
// path.
func
(
c
*
Client
)
SaveModel
(
path
string
)
error
{
// Save indicates parameters to save the parameter to the given path.
func
(
c
*
Client
)
Save
(
path
string
)
error
{
errCh
:=
make
(
chan
error
,
len
(
c
.
pservers
))
for
_
,
p
:=
range
c
.
pservers
{
var
dummy
int
err
:=
p
.
Call
(
"Service.Save"
,
path
,
&
dummy
)
errCh
<-
err
}
recv
:=
0
for
err
:=
range
errCh
{
if
err
!=
nil
{
return
err
}
recv
++
if
recv
==
len
(
c
.
pservers
)
{
break
}
}
// TODO(helin): there will be many files under path, need to
// merge them into a single file.
return
nil
}
// Cleanup cleans up the client states.
func
(
c
*
Client
)
Cleanup
()
{
// strHash returns the 32-bit FNV-1a hash of s.
func strHash(s string) uint32 {
	hasher := fnv.New32a()
	data := []byte(s)
	hasher.Write(data)
	sum := hasher.Sum32()
	return sum
}
// TODO(helin): now partition only select which parameter server to
// send the entire parameter. We need to partition a parameter into
// small blocks and send to different parameter servers.
func
(
c
*
Client
)
partition
(
key
string
)
int
{
return
int
(
strHash
(
key
)
%
uint32
(
len
(
c
.
pservers
)))
}
This diff is collapsed.
Click to expand it.
go/pserver/optimizer.go
浏览文件 @
0babf84b
...
...
@@ -29,11 +29,11 @@ func newOptimizer(t optimizerType, learning_rate float64) *optimizer {
func
(
o
*
optimizer
)
UpdateParameter
(
p
Parameter
,
g
Gradient
)
error
{
if
len
(
p
.
Content
)
!=
len
(
g
.
Content
)
{
return
fmt
.
Errorf
(
"
parameter and gradient length not match, parameter: %d, gradient: %d"
,
len
(
p
.
Content
),
len
(
g
.
Content
))
return
fmt
.
Errorf
(
"
Name: %s, parameter and gradient length not match, parameter: %d, gradient: %d"
,
p
.
Name
,
len
(
p
.
Content
),
len
(
g
.
Content
))
}
if
p
.
ElementType
!=
g
.
ElementType
{
return
fmt
.
Errorf
(
"
parameter and gradient element type not match, parameter: %v, gradient: %v"
,
p
.
ElementType
,
g
.
ElementType
)
return
fmt
.
Errorf
(
"
Name: %s, parameter and gradient element type not match, parameter: %v, gradient: %v"
,
p
.
Name
,
p
.
ElementType
,
g
.
ElementType
)
}
r
:=
C
.
paddle_update_parameter
(
o
.
opt
,
unsafe
.
Pointer
(
&
p
.
Content
[
0
]),
C
.
paddle_element_type
(
p
.
ElementType
),
unsafe
.
Pointer
(
&
g
.
Content
[
0
]),
C
.
int
(
len
(
g
.
Content
)))
...
...
This diff is collapsed.
Click to expand it.
go/pserver/service.go
浏览文件 @
0babf84b
...
...
@@ -49,33 +49,12 @@ type Service struct {
// NewService creates a new service.
func
NewService
()
*
Service
{
s
:=
&
Service
{}
s
:=
&
Service
{
opt
:
newOptimizer
(
sgd
,
0.01
)
}
s
.
paramMap
=
make
(
map
[
string
]
Parameter
)
s
.
initialized
=
make
(
chan
struct
{})
return
s
}
// BeginInitParams tells the parameter server that the parameter
// initialization has begun.
func
(
s
*
Service
)
BeginInitParams
(
config
[]
byte
,
dummy
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
return
ErrAlreadyInitialized
default
:
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
if
s
.
opt
!=
nil
{
s
.
opt
.
Cleanup
()
}
// TODO(helin): parse learning rate from config
s
.
opt
=
newOptimizer
(
sgd
,
0.01
)
return
nil
}
// InitParam initializes a parameter.
func
(
s
*
Service
)
InitParam
(
paramWithConfigs
ParameterWithConfig
,
dummy
*
int
)
error
{
select
{
...
...
@@ -109,75 +88,45 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
return
nil
}
// SendGrad
s sends gradients
to parameter servers for parameter
// SendGrad
sends gradient
to parameter servers for parameter
// optimization.
func
(
s
*
Service
)
SendGrad
s
(
grads
[]
Gradient
,
dummy
*
int
)
error
{
func
(
s
*
Service
)
SendGrad
(
g
Gradient
,
dummy
*
int
)
error
{
select
{
case
<-
s
.
initialized
:
default
:
return
ErrUninitialized
}
count
:=
len
(
grads
)
if
count
==
0
{
return
nil
}
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
for
_
,
g
:=
range
grads
{
if
_
,
ok
:=
s
.
paramMap
[
g
.
Name
];
!
ok
{
return
fmt
.
Errorf
(
"parameter: %s does not exist"
,
g
.
Name
)
}
}
errCh
:=
make
(
chan
error
,
count
)
for
_
,
g
:=
range
grads
{
go
func
(
p
Parameter
,
g
Gradient
)
{
err
:=
s
.
opt
.
UpdateParameter
(
p
,
g
)
errCh
<-
err
}(
s
.
paramMap
[
g
.
Name
],
g
)
p
,
ok
:=
s
.
paramMap
[
g
.
Name
]
if
!
ok
{
return
fmt
.
Errorf
(
"parameter: %s does not exist"
,
g
.
Name
)
}
recv
:=
0
for
err
:=
range
errCh
{
if
err
!=
nil
{
return
err
}
recv
++
if
recv
==
count
{
break
}
}
return
nil
return
s
.
opt
.
UpdateParameter
(
p
,
g
)
}
// GetParam
s
gets parameters from the parameter server.
func
(
s
*
Service
)
GetParam
s
(
names
[]
string
,
parameters
*
[]
Parameter
)
error
{
// GetParam gets parameters from the parameter server.
func
(
s
*
Service
)
GetParam
(
name
string
,
parameter
*
Parameter
)
error
{
<-
s
.
initialized
s
.
mu
.
Lock
()
defer
s
.
mu
.
Unlock
()
for
_
,
n
:=
range
names
{
if
_
,
ok
:=
s
.
paramMap
[
n
];
!
ok
{
return
fmt
.
Errorf
(
"parameter: %s does not exist"
,
n
)
}
}
*
parameters
=
make
([]
Parameter
,
len
(
names
))
for
i
,
n
:=
range
names
{
// The parameter content (a byte slice) may change
// during RPC serialization due to write from other
// goroutine, we allow it since mini-batch based deep
// learning optimization methods are stochastic in
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// parameter content.
(
*
parameters
)[
i
]
=
s
.
paramMap
[
n
]
p
,
ok
:=
s
.
paramMap
[
name
]
if
!
ok
{
return
fmt
.
Errorf
(
"parameter: %s does not exist"
,
name
)
}
// The parameter content (a byte slice) may change
// during RPC serialization due to write from other
// goroutine, we allow it since mini-batch based deep
// learning optimization methods are stochastic in
// nature. This race condition is allowed deliberately
// to save the program from making a copy of the
// parameter content.
*
parameter
=
p
return
nil
}
...
...
This diff is collapsed.
Click to expand it.
go/pserver/service_test.go
浏览文件 @
0babf84b
...
...
@@ -4,23 +4,19 @@ import (
"reflect"
"sync"
"testing"
"time"
"github.com/PaddlePaddle/Paddle/go/pserver"
)
func
TestFull
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
var
dummy
int
err
:=
s
.
BeginInitParams
(
nil
,
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
var
p
pserver
.
Parameter
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
err
=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p
,
nil
},
&
dummy
)
var
dummy
int
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p
,
nil
},
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
...
...
@@ -39,40 +35,39 @@ func TestFull(t *testing.T) {
t
.
FailNow
()
}
var
param
s
[]
pserver
.
Parameter
err
=
s
.
GetParam
s
([]
string
{
"param_b"
,
"param_a"
},
&
params
)
var
param
pserver
.
Parameter
err
=
s
.
GetParam
(
"param_b"
,
&
param
)
if
err
!=
nil
{
t
.
FailNow
()
}
if
len
(
params
)
!=
2
||
!
reflect
.
DeepEqual
(
params
[
0
],
p1
)
||
!
reflect
.
DeepEqual
(
params
[
0
]
,
p1
)
{
if
!
reflect
.
DeepEqual
(
param
,
p1
)
{
t
.
FailNow
()
}
g
rads
:=
[]
pserver
.
Gradient
{
pserver
.
Gradient
(
p1
),
pserver
.
Gradient
(
p
)}
err
=
s
.
SendGrad
s
(
grads
,
&
dummy
)
g
1
,
g2
:=
pserver
.
Gradient
(
p1
),
pserver
.
Gradient
(
p
)
err
=
s
.
SendGrad
(
g1
,
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
err
=
s
.
SendGrad
(
g2
,
&
dummy
)
var
params1
[]
pserver
.
Parameter
err
=
s
.
GetParams
([]
string
{
"param_b"
,
"param_a"
},
&
params1
)
if
err
!=
nil
{
t
.
FailNow
()
}
if
len
(
params
)
!=
2
{
var
param1
pserver
.
Parameter
err
=
s
.
GetParam
(
"param_a"
,
&
param1
)
if
err
!=
nil
{
t
.
FailNow
()
}
// don't compare content, since it's already changed by
// gradient update.
params1
[
0
]
.
Content
=
nil
params1
[
0
]
.
Content
=
nil
param1
.
Content
=
nil
p
.
Content
=
nil
p1
.
Content
=
nil
if
!
reflect
.
DeepEqual
(
param
s1
[
0
],
p1
)
||
!
reflect
.
DeepEqual
(
params1
[
0
],
p1
)
{
if
!
reflect
.
DeepEqual
(
param
1
,
p
)
{
t
.
FailNow
()
}
}
...
...
@@ -80,19 +75,7 @@ func TestFull(t *testing.T) {
func
TestMultipleInit
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
var
dummy
int
err
:=
s
.
BeginInitParams
(
nil
,
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
// this is fine, it's possible for client to call init
// multiple times.
err
=
s
.
BeginInitParams
(
nil
,
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
err
:=
s
.
FinishInitParams
(
0
,
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
...
...
@@ -101,17 +84,12 @@ func TestMultipleInit(t *testing.T) {
if
err
!=
pserver
.
ErrAlreadyInitialized
{
t
.
FailNow
()
}
err
=
s
.
BeginInitParams
(
nil
,
&
dummy
)
if
err
!=
pserver
.
ErrAlreadyInitialized
{
t
.
FailNow
()
}
}
func
TestUninitialized
(
t
*
testing
.
T
)
{
s
:=
pserver
.
NewService
()
var
dummy
int
err
:=
s
.
SendGrad
s
(
nil
,
&
dummy
)
err
:=
s
.
SendGrad
(
pserver
.
Gradient
{}
,
&
dummy
)
if
err
!=
pserver
.
ErrUninitialized
{
t
.
FailNow
()
}
...
...
@@ -123,8 +101,8 @@ func TestBlockUntilInitialized(t *testing.T) {
var
wg
sync
.
WaitGroup
wg
.
Add
(
1
)
go
func
()
{
var
param
s
[]
pserver
.
Parameter
err
:=
s
.
GetParam
s
(
nil
,
&
params
)
var
param
pserver
.
Parameter
err
:=
s
.
GetParam
(
"param_a"
,
&
param
)
if
err
!=
nil
{
t
.
FailNow
()
}
...
...
@@ -143,11 +121,7 @@ func TestBlockUntilInitialized(t *testing.T) {
ch
<-
struct
{}{}
}()
var
dummy
int
err
:=
s
.
BeginInitParams
(
nil
,
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
time
.
Sleep
(
50
*
time
.
Millisecond
)
select
{
case
<-
ch
:
...
...
@@ -156,6 +130,16 @@ func TestBlockUntilInitialized(t *testing.T) {
default
:
}
var
p
pserver
.
Parameter
p
.
Name
=
"param_a"
p
.
Content
=
[]
byte
{
1
,
0
,
0
,
0
,
2
,
0
,
0
,
0
,
3
,
0
,
0
,
0
}
p
.
ElementType
=
pserver
.
Int32
var
dummy
int
err
:=
s
.
InitParam
(
pserver
.
ParameterWithConfig
{
p
,
nil
},
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
}
err
=
s
.
FinishInitParams
(
0
,
&
dummy
)
if
err
!=
nil
{
t
.
FailNow
()
...
...
This diff is collapsed.
Click to expand it.
paddle/go/pserver/client_test.go
0 → 100644
浏览文件 @
0babf84b
package
pserver_test
import
(
"net"
"net/http"
"net/rpc"
"strconv"
"strings"
"testing"
"github.com/PaddlePaddle/Paddle/paddle/go/pserver"
)
const
numPserver
=
10
var
port
[
numPserver
]
int
func
init
()
{
for
i
:=
0
;
i
<
numPserver
;
i
++
{
l
,
err
:=
net
.
Listen
(
"tcp"
,
":0"
)
if
err
!=
nil
{
panic
(
err
)
}
ss
:=
strings
.
Split
(
l
.
Addr
()
.
String
(),
":"
)
p
,
err
:=
strconv
.
Atoi
(
ss
[
len
(
ss
)
-
1
])
if
err
!=
nil
{
panic
(
err
)
}
port
[
i
]
=
p
go
func
(
l
net
.
Listener
)
{
s
:=
pserver
.
NewService
()
server
:=
rpc
.
NewServer
()
err
:=
server
.
Register
(
s
)
if
err
!=
nil
{
panic
(
err
)
}
mux
:=
http
.
NewServeMux
()
mux
.
Handle
(
rpc
.
DefaultRPCPath
,
server
)
err
=
http
.
Serve
(
l
,
mux
)
if
err
!=
nil
{
panic
(
err
)
}
}(
l
)
}
}
// selector is a constant answer used by the tests as the client's
// selector: the value itself says whether the client is selected.
type selector bool

// Select returns the fixed answer stored in s.
func (s selector) Select() bool {
	if s {
		return true
	}
	return false
}
type
lister
[]
pserver
.
Server
func
(
l
lister
)
List
()
[]
pserver
.
Server
{
return
l
}
func
TestClientFull
(
t
*
testing
.
T
)
{
servers
:=
make
([]
pserver
.
Server
,
numPserver
)
for
i
:=
0
;
i
<
numPserver
;
i
++
{
servers
[
i
]
=
pserver
.
Server
{
Index
:
i
,
Addr
:
":"
+
strconv
.
Itoa
(
port
[
i
])}
}
c
:=
pserver
.
NewClient
(
lister
(
servers
),
len
(
servers
),
selector
(
true
))
selected
:=
c
.
BeginInitParams
()
if
!
selected
{
t
.
Fatal
(
"should be selected."
)
}
const
numParameter
=
100
for
i
:=
0
;
i
<
numParameter
;
i
++
{
var
p
pserver
.
Parameter
p
.
Name
=
"p_"
+
strconv
.
Itoa
(
i
)
p
.
ElementType
=
pserver
.
Float32
p
.
Content
=
make
([]
byte
,
(
i
+
1
)
*
100
)
err
:=
c
.
InitParam
(
pserver
.
ParameterWithConfig
{
Param
:
p
})
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
}
err
:=
c
.
FinishInitParams
()
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
var
grads
[]
pserver
.
Gradient
for
i
:=
0
;
i
<
numParameter
/
2
;
i
++
{
var
g
pserver
.
Gradient
g
.
Name
=
"p_"
+
strconv
.
Itoa
(
i
)
g
.
ElementType
=
pserver
.
Float32
g
.
Content
=
make
([]
byte
,
(
i
+
1
)
*
100
)
grads
=
append
(
grads
,
g
)
}
err
=
c
.
SendGrads
(
grads
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
names
:=
make
([]
string
,
numParameter
)
for
i
:=
0
;
i
<
numParameter
;
i
++
{
names
[
i
]
=
"p_"
+
strconv
.
Itoa
(
i
)
}
params
,
err
:=
c
.
GetParams
(
names
)
if
err
!=
nil
{
t
.
Fatal
(
err
)
}
if
len
(
names
)
!=
len
(
params
)
{
t
.
Fatalf
(
"parameter size not match, need: %d, have: %d"
,
len
(
names
),
len
(
params
))
}
for
i
:=
range
params
{
if
names
[
i
]
!=
params
[
i
]
.
Name
{
t
.
Fatalf
(
"order of returned parameter does not required: parameter name: %s, required name: %s"
,
names
[
i
],
params
[
i
])
}
}
}
This diff is collapsed.
Click to expand it.
paddle/go/pserver/internal/connection/conn.go
0 → 100644
浏览文件 @
0babf84b
package
connection
import
(
"errors"
"net/rpc"
"sync"
)
// TODO(helin): add TCP re-connect logic

// Conn is a connection to a parameter server.
//
// A Conn starts out unconnected. Call blocks until Connect has
// established an RPC client, so callers may issue calls before the
// server address is known.
type Conn struct {
	// mu guards client and waitConn.
	mu sync.Mutex
	// client is the current RPC client, or nil while unconnected.
	client *rpc.Client
	// waitConn is non-nil while at least one Call is waiting for a
	// client; it is closed (and reset to nil) once a client is set.
	waitConn chan struct{}
}

// New creates a new, unconnected connection.
func New() *Conn {
	return &Conn{}
}

// Connect connects the connection to an address, replacing any
// existing RPC client.
//
// The previous client, if any, is closed and dropped first. If closing
// it fails, the error is returned and no dial is attempted, but the
// stale client is still dropped so that a later Connect can succeed
// (a client whose Close has been called once keeps reporting
// rpc.ErrShutdown from Close afterwards).
//
// The dial is performed without holding the lock. If another goroutine
// installed a client in the meantime, the freshly dialed client is
// closed and an error is returned.
func (c *Conn) Connect(addr string) error {
	c.mu.Lock()
	if old := c.client; old != nil {
		c.client = nil
		if err := old.Close(); err != nil {
			c.mu.Unlock()
			return err
		}
	}
	c.mu.Unlock()

	client, err := rpc.DialHTTP("tcp", addr)
	if err != nil {
		return err
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	if c.client != nil {
		// Lost the race against a concurrent Connect. Close our client
		// so its socket and reader goroutine are not leaked; the close
		// error is irrelevant since this client was never used.
		_ = client.Close()
		return errors.New("client already set from a concurrent goroutine")
	}

	c.client = client
	if c.waitConn != nil {
		// Wake up every Call that is waiting for a connection.
		close(c.waitConn)
		c.waitConn = nil
	}
	return nil
}

// Call makes an RPC call.
//
// Call will be blocked until the connection to the remote RPC service
// is established.
func (c *Conn) Call(serviceMethod string, args interface{}, reply interface{}) error {
	for {
		c.mu.Lock()
		client := c.client
		var waitCh chan struct{}
		if client == nil {
			// Share a single wait channel among all blocked callers.
			if c.waitConn == nil {
				c.waitConn = make(chan struct{})
			}
			waitCh = c.waitConn
		}
		c.mu.Unlock()

		if client != nil {
			return client.Call(serviceMethod, args, reply)
		}

		// Wait until a new connection is established, then re-check:
		// the client may already have been replaced again.
		<-waitCh
	}
}
This diff is collapsed.
Click to expand it.
paddle/go/pserver/partitioner.go
0 → 100644
浏览文件 @
0babf84b
package
pserver
// partitioner partitions the parameters into shards.
type partitioner struct {
	// shardNum is the number of shards.
	shardNum int
}

// newPartitioner returns a partitioner configured for shardNum shards.
func newPartitioner(shardNum int) *partitioner {
	p := new(partitioner)
	p.shardNum = shardNum
	return p
}
This diff is collapsed.
Click to expand it.
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录