Commit 5f5e128d, authored by helinwang, committed by GitHub

Merge pull request #2429 from helinwang/master_client

implement master server client, remove unnecessary dummy variable
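For orientation, here is a hedged sketch (not part of this commit) of how a trainer process might drive the new master client API introduced below. Only NewClient, Addresser, SetDataset, GetTask, TaskFinished, and Task.ID come from the diff; the staticAddresser type, the master address, and the RecordIO glob pattern are illustrative assumptions.

// Hypothetical trainer-side usage of the master client; not from
// this commit. staticAddresser and the paths are made up.
package main

import (
	"fmt"

	"github.com/PaddlePaddle/Paddle/go/master"
)

// staticAddresser is an assumed Addresser that always returns a
// fixed master server address.
type staticAddresser string

func (a staticAddresser) Address() string {
	return string(a)
}

func main() {
	c := master.NewClient(staticAddresser("localhost:8080"))

	// Every trainer may call SetDataset; the master honors only
	// the first call.
	err := c.SetDataset([]string{"/data/train-*.recordio"})
	if err != nil {
		panic(err)
	}

	for {
		t, err := c.GetTask()
		if err != nil {
			// Out of tasks for the current pass (or all
			// tasks failed); a real trainer would retry.
			fmt.Println(err)
			break
		}

		// ... train on the chunks described by t ...

		err = c.TaskFinished(t.ID)
		if err != nil {
			fmt.Println(err)
		}
	}
}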
 package main
 
 import (
-	"fmt"
 	"net"
 	"net/http"
 	"net/rpc"
-	"os"
-	"path/filepath"
 	"strconv"
-	"strings"
 	"time"
 
 	"github.com/namsral/flag"
 
 	"github.com/PaddlePaddle/Paddle/go/master"
-	"github.com/PaddlePaddle/recordio"
 )
 
 func main() {
 	port := flag.Int("port", 8080, "port of the master server.")
-	dataset := flag.String("training_dataset", "", "dataset: comma separated paths to RecordIO files, supports glob patterns.")
 	faultTolerance := flag.Bool("fault_tolerance", false, "enable fault tolerance (requires etcd).")
 	taskTimeoutDur := flag.Duration("task_timout_dur", 20*time.Minute, "task timeout duration.")
 	taskTimeoutMax := flag.Int("task_timeout_max", 3, "max timeout count for each task before it is declared a failed task.")
 	chunkPerTask := flag.Int("chunk_per_task", 10, "chunk per task.")
 	flag.Parse()
 
-	if *dataset == "" {
-		panic("no dataset specified.")
-	}
-
 	if *faultTolerance {
 		panic("fault tolerance not implemented.")
 	}
-
-	var chunks []master.Chunk
-	var paths []string
-	ss := strings.Split(*dataset, ",")
-	fmt.Println(ss)
-	for _, s := range ss {
-		match, err := filepath.Glob(s)
-		if err != nil {
-			panic(err)
-		}
-		paths = append(paths, match...)
-	}
-
-	if len(paths) == 0 {
-		panic("no valid dataset specified.")
-	}
-
-	idx := 0
-	for _, path := range paths {
-		f, err := os.Open(path)
-		if err != nil {
-			panic(err)
-		}
-
-		index, err := recordio.LoadIndex(f)
-		if err != nil {
-			panic(err)
-		}
-		f.Close()
-
-		count := index.NumChunks()
-		for i := 0; i < count; i++ {
-			chunk := master.Chunk{
-				Idx:   idx,
-				Path:  path,
-				Index: *index.ChunkIndex(i),
-			}
-			chunks = append(chunks, chunk)
-		}
-	}
-
-	s := master.NewService(chunks, *chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
+	s := master.NewService(*chunkPerTask, *taskTimeoutDur, *taskTimeoutMax)
 	err := rpc.Register(s)
 	if err != nil {
 		panic(err)
......
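Since the dataset flags and RecordIO indexing moved out of this command, the master binary now only registers the service; the dataset arrives later via Service.SetDataset. A minimal sketch of serving the service over net/rpc and HTTP, mirroring the truncated main above and the test init further below (the listen address and flag values are assumptions):

// Sketch: start the master service and serve it over HTTP RPC.
// The dataset is no longer passed on the command line; a trainer
// sets it remotely via Service.SetDataset.
package main

import (
	"net"
	"net/http"
	"net/rpc"
	"time"

	"github.com/PaddlePaddle/Paddle/go/master"
)

func main() {
	s := master.NewService(10, 20*time.Minute, 3)

	server := rpc.NewServer()
	if err := server.Register(s); err != nil {
		panic(err)
	}

	mux := http.NewServeMux()
	mux.Handle(rpc.DefaultRPCPath, server)

	l, err := net.Listen("tcp", ":8080")
	if err != nil {
		panic(err)
	}

	if err := http.Serve(l, mux); err != nil {
		panic(err)
	}
}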
@@ -2,6 +2,7 @@ package connection
 
 import (
 	"errors"
+	"log"
 	"net/rpc"
 	"sync"
 )
@@ -21,6 +22,18 @@ func New() *Conn {
 	return c
 }
 
+// Close closes the connection.
+func (c *Conn) Close() error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	if c.client == nil {
+		return nil
+	}
+
+	return c.client.Close()
+}
+
 // Connect connects the connection to an address.
 func (c *Conn) Connect(addr string) error {
 	c.mu.Lock()
@@ -50,12 +63,20 @@ func (c *Conn) Connect(addr string) error {
 			c.waitConn = nil
 		}
 	} else {
+		err := client.Close()
+		if err != nil {
+			log.Println(err)
+		}
+
 		return errors.New("client already set from a concurrent goroutine")
 	}
 
 	return nil
 }
 
+// TODO(helin): refactor Call to be able to perform the given retry
+// policy.
+
 // Call makes an RPC call.
 //
 // Call will be blocked until the connection to remote RPC service
......
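For context, a hedged sketch of the connection.Conn lifecycle with the newly added Close: Connect may be retried as addresses change, Call blocks until a connection is established, and Close is safe even when no client was ever connected. The address and the RPC method name "Service.SomeMethod" are made up for illustration.

// Illustrative use of connection.Conn; signatures follow the diff
// above, but the server and method name are assumptions.
package main

import (
	"log"

	"github.com/PaddlePaddle/Paddle/go/connection"
)

func main() {
	conn := connection.New()

	// Connect can be called again whenever the server address
	// changes; here we connect once.
	if err := conn.Connect("localhost:8080"); err != nil {
		log.Println(err)
	}

	// Call blocks until the connection to the remote RPC service
	// is established (so with no server running, this would wait).
	var reply int
	if err := conn.Call("Service.SomeMethod", 0, &reply); err != nil {
		log.Println(err)
	}

	// Close is a no-op if no client was ever set.
	if err := conn.Close(); err != nil {
		log.Println(err)
	}
}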
package master

import (
	"log"
	"time"

	"github.com/PaddlePaddle/Paddle/go/connection"
)

// Addresser provides the address of the master server.
type Addresser interface {
	Address() string
}

// Client is the client of the master server.
type Client struct {
	conn *connection.Conn
}

// NewClient creates a new Client.
func NewClient(addr Addresser) *Client {
	c := &Client{}
	c.conn = connection.New()
	go c.monitorMaster(addr)
	return c
}

func (c *Client) monitorMaster(addr Addresser) {
	lastMaster := ""
	monitor := func() {
		// get the latest address of the master server,
		// connect to the new address once the address changes.
		curMaster := addr.Address()
		if curMaster != lastMaster {
			if curMaster == "" {
				err := c.conn.Close()
				if err != nil {
					log.Println(err)
				}
			} else {
				err := c.conn.Connect(curMaster)
				if err != nil {
					log.Println(err)

					// connect to addr failed, set
					// to last known addr in order
					// to retry next time.
					curMaster = lastMaster
				}
			}
		}

		lastMaster = curMaster
	}

	monitor()
	ticker := time.NewTicker(10 * time.Second)
	for range ticker.C {
		monitor()
	}
}

// SetDataset sets the dataset for the master server to dispatch.
//
// SetDataset can be called multiple times from different nodes, but
// only the first call will be honored.
func (c *Client) SetDataset(globPaths []string) error {
	return c.conn.Call("Service.SetDataset", globPaths, nil)
}

// GetTask gets a new task from the master server.
func (c *Client) GetTask() (Task, error) {
	var t Task
	err := c.conn.Call("Service.GetTask", 0, &t)
	return t, err
}

// TaskFinished tells the master server a task is finished.
func (c *Client) TaskFinished(taskID int) error {
	return c.conn.Call("Service.TaskFinished", taskID, nil)
}
package master_test

import (
	"fmt"
	"net"
	"net/http"
	"net/rpc"
	"os"
	"strconv"
	"strings"
	"testing"
	"time"

	log "github.com/sirupsen/logrus"

	"github.com/PaddlePaddle/Paddle/go/master"
	"github.com/PaddlePaddle/recordio"
)

const (
	totalTask    = 20
	chunkPerTask = 10
)

var port int

func init() {
	log.SetLevel(log.ErrorLevel)

	l, err := net.Listen("tcp", ":0")
	if err != nil {
		panic(err)
	}

	ss := strings.Split(l.Addr().String(), ":")
	p, err := strconv.Atoi(ss[len(ss)-1])
	if err != nil {
		panic(err)
	}
	port = p

	go func(l net.Listener) {
		s := master.NewService(chunkPerTask, time.Second, 1)
		server := rpc.NewServer()
		err := server.Register(s)
		if err != nil {
			panic(err)
		}

		mux := http.NewServeMux()
		mux.Handle(rpc.DefaultRPCPath, server)
		err = http.Serve(l, mux)
		if err != nil {
			panic(err)
		}
	}(l)
}

type addresser string

func (a addresser) Address() string {
	return string(a)
}

func TestClientFull(t *testing.T) {
	const p = "/tmp/master_client_test_0"
	f, err := os.Create(p)
	if err != nil {
		panic(err)
	}

	for i := 0; i < totalTask*chunkPerTask; i++ {
		w := recordio.NewWriter(f, -1, -1)
		w.Write(nil)
		// call Close to force RecordIO writing a chunk.
		w.Close()
	}
	f.Close()

	c := master.NewClient(addresser(fmt.Sprintf(":%d", port)))
	c.SetDataset([]string{p})

	checkOnePass := func(i int) {
		var tasks []master.Task
		for i := 0; i < totalTask; i++ {
			task, err := c.GetTask()
			if err != nil {
				t.Fatal(i, err)
			}
			tasks = append(tasks, task)
		}

		_, err = c.GetTask()
		if err == nil {
			t.Fatal(i, "should get error.")
		}

		err = c.TaskFinished(tasks[0].ID)
		if err != nil {
			t.Fatal(err)
		}
		tasks = tasks[1:]

		task, err := c.GetTask()
		if err != nil {
			t.Fatal(err)
		}
		tasks = append(tasks, task)

		for _, task := range tasks {
			err = c.TaskFinished(task.ID)
			if err != nil {
				t.Fatal(i, err)
			}
		}
	}

	for i := 0; i < 10; i++ {
		checkOnePass(i)
	}
}
@@ -2,29 +2,25 @@ package master
 
 import (
 	"errors"
-	"log"
+	"os"
+	"path/filepath"
 	"sync"
 	"time"
 
+	log "github.com/sirupsen/logrus"
+
 	"github.com/PaddlePaddle/recordio"
 )
 
-const (
-	targetTaskCount = 300
-)
-
-// errors
-var (
-	ErrNoMoreTask          = errors.New("no more task for current pass")
-	ErrPendingTaskNotFound = errors.New("pending task not found")
-)
-
 // Service is the master server service.
 type Service struct {
-	timeoutDur time.Duration
-	timeoutMax int
+	chunksPerTask int
+	timeoutDur    time.Duration
+	timeoutMax    int
+	ready         chan struct{}
 
 	mu         sync.Mutex
+	initDone   bool
 	taskQueues taskQueues
 }
@@ -55,7 +51,6 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
 	if len(cur.Task.Chunks) > 0 {
 		cur.Task.ID = id
-		id++
 		result = append(result, cur)
 	}
@@ -63,21 +58,21 @@ func partition(chunks []Chunk, chunksPerTask int) []taskEntry {
 }
 
 // NewService creates a new service.
-func NewService(chunks []Chunk, chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
+func NewService(chunksPerTask int, timeoutDur time.Duration, timeoutMax int) *Service {
 	s := &Service{}
+	s.chunksPerTask = chunksPerTask
 	s.timeoutDur = timeoutDur
 	s.timeoutMax = timeoutMax
 	s.taskQueues = taskQueues{}
 	s.taskQueues.Pending = make(map[int]taskEntry)
-	s.taskQueues.Todo = partition(chunks, chunksPerTask)
+	s.ready = make(chan struct{})
 	return s
 }
 
 // Chunk is a chunk of data consisting of several data instances.
 type Chunk struct {
-	Idx   int // index of the chunk within the file
 	Path  string
-	Index recordio.Index // block index
+	Index recordio.Index // chunk index
 }
 
 // Task is the basic unit of data instances assigned to trainers.
@@ -105,74 +100,205 @@ func (s *Service) snapshot() error {
 	return nil
 }
 
-// GetTask gets a new task from the service.
-func (s *Service) GetTask(dummy int, task *Task) error {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	if len(s.taskQueues.Todo) == 0 {
-		return ErrNoMoreTask
-	}
-
-	t := s.taskQueues.Todo[0]
-	t.Epoch++
-	s.taskQueues.Todo = s.taskQueues.Todo[1:]
-	s.taskQueues.Pending[t.Task.ID] = t
-	err := s.snapshot()
-	if err != nil {
-		return err
-	}
-
-	time.AfterFunc(s.timeoutDur, func(taskID int, epoch int) func() {
-		return func() {
-			s.mu.Lock()
-			defer s.mu.Unlock()
-
-			t, ok := s.taskQueues.Pending[taskID]
-			if !ok {
-				return
-			}
-
-			if t.Epoch != epoch {
-				// new epoch, task launched after the
-				// schedule of this timeout check.
-				return
-			}
-
-			defer func() {
-				err := s.snapshot()
-				if err != nil {
-					log.Println(err)
-				}
-			}()
-
-			delete(s.taskQueues.Pending, t.Task.ID)
-
-			t.NumTimeout++
-			if t.NumTimeout > s.timeoutMax {
-				s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
-				return
-			}
-
-			s.taskQueues.Todo = append(s.taskQueues.Todo, t)
-		}
-	}(t.Task.ID, t.Epoch))
-	return nil
-}
+func readChunks(globPaths []string) ([]Chunk, error) {
+	var chunks []Chunk
+	var paths []string
+	for _, s := range globPaths {
+		match, err := filepath.Glob(s)
+		if err != nil {
+			return nil, err
+		}
+		paths = append(paths, match...)
+	}
+
+	if len(paths) == 0 {
+		return nil, errors.New("no valid dataset specified")
+	}
+
+	for _, path := range paths {
+		f, err := os.Open(path)
+		if err != nil {
+			return nil, err
+		}
+
+		index, err := recordio.LoadIndex(f)
+		if err != nil {
+			return nil, err
+		}
+
+		err = f.Close()
+		if err != nil {
+			return nil, err
+		}
+
+		count := index.NumChunks()
+		for i := 0; i < count; i++ {
+			chunk := Chunk{
+				Path:  path,
+				Index: *index.ChunkIndex(i),
+			}
+			chunks = append(chunks, chunk)
+		}
+	}
+
+	return chunks, nil
+}
+
+// SetDataset sets the dataset for the master server to dispatch.
+//
+// SetDataset can be called multiple times, but only the first call
+// will be honored.
+func (s *Service) SetDataset(globPaths []string, dummy *int) error {
+	if len(globPaths) == 0 {
+		return errors.New("no dataset specified")
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	if s.initDone {
+		// Already initialized. All trainers will call
+		// SetDataset, but we only handle the first one. Treat
+		// other calls as successful but do nothing.
+		return nil
+	}
+
+	chunks, err := readChunks(globPaths)
+	if err != nil {
+		return err
+	}
+
+	s.taskQueues.Todo = partition(chunks, s.chunksPerTask)
+
+	err = s.snapshot()
+	if err != nil {
+		log.Errorln(err)
+		return err
+	}
+
+	close(s.ready)
+	s.initDone = true
+	return nil
+}
+
+func (s *Service) checkTimeoutFunc(taskID int, epoch int) func() {
+	return func() {
+		s.mu.Lock()
+		defer s.mu.Unlock()
+
+		t, ok := s.taskQueues.Pending[taskID]
+		if !ok {
+			return
+		}
+
+		if t.Epoch != epoch {
+			// new epoch, task launched after the
+			// schedule of this timeout check.
+			return
+		}
+
+		defer func() {
+			err := s.snapshot()
+			if err != nil {
+				log.Errorln(err)
+			}
+		}()
+
+		delete(s.taskQueues.Pending, t.Task.ID)
+
+		t.NumTimeout++
+		if t.NumTimeout > s.timeoutMax {
+			log.Warningf("Task %v failed %d times, discard.\n", t.Task, t.NumTimeout)
+			s.taskQueues.Failed = append(s.taskQueues.Failed, t.Task)
+			return
+		}
+
+		log.Warningf("Task %v failed %d times, retry.\n", t.Task, t.NumTimeout)
+		s.taskQueues.Todo = append(s.taskQueues.Todo, t)
+	}
+}
+
+// GetTask gets a new task from the service.
+func (s *Service) GetTask(dummy int, task *Task) error {
+	select {
+	case <-s.ready:
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if len(s.taskQueues.Todo) == 0 {
+		if len(s.taskQueues.Done) == 0 {
+			if len(s.taskQueues.Pending) == 0 {
+				err := errors.New("all task failed")
+				log.Warningln(err)
+				return err
+			}
+
+			// TODO(helin): clients need to retry in this
+			// error case. Gotcha: the RPC client can't
+			// compare the returned error with predefined
+			// errors like io.EOF, because the error
+			// instance deserialized from RPC is a
+			// different instance than the error defined
+			// in the package. So we need to figure out a
+			// way for the client to check this error
+			// correctly.
+			err := errors.New("no more available task")
+			log.Warningln(err)
+			return err
+		}
+		s.taskQueues.Todo = s.taskQueues.Done
+		s.taskQueues.Done = nil
+		log.Infoln("No more todo task, but trainer is requesting task to do. Move all done task to todo.")
+	}
+
+	t := s.taskQueues.Todo[0]
+	t.Epoch++
+	s.taskQueues.Todo = s.taskQueues.Todo[1:]
+	s.taskQueues.Pending[t.Task.ID] = t
+	err := s.snapshot()
+	if err != nil {
+		return err
+	}
+
+	*task = t.Task
+	log.Infof("Task #%d dispatched\n", task.ID)
+
+	time.AfterFunc(s.timeoutDur, s.checkTimeoutFunc(t.Task.ID, t.Epoch))
+	return nil
+}
 
 // TaskFinished tells the service that a task is finished.
 func (s *Service) TaskFinished(taskID int, dummy *int) error {
+	select {
+	case <-s.ready:
+	}
+
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
+	log.Infof("Task %d finished\n", taskID)
+
 	t, ok := s.taskQueues.Pending[taskID]
 	if !ok {
-		return ErrPendingTaskNotFound
+		err := errors.New("pending task not found")
+		log.Warningln(err)
+		return err
 	}
 
 	// task finished, reset timeout
 	t.NumTimeout = 0
 	s.taskQueues.Done = append(s.taskQueues.Done, t)
 	delete(s.taskQueues.Pending, taskID)
-	return s.snapshot()
+
+	if len(s.taskQueues.Pending) == 0 && len(s.taskQueues.Todo) == 0 {
+		log.Infoln("No more todo and pending task, start a new pass.")
+		s.taskQueues.Todo = append(s.taskQueues.Todo, s.taskQueues.Done...)
+		s.taskQueues.Done = nil
+	}
+
+	err := s.snapshot()
+	if err != nil {
+		log.Errorln(err)
+	}
+	return err
 }
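The restructured Service gates GetTask and TaskFinished on a ready channel that the first successful SetDataset closes exactly once. A standalone sketch of this block-until-initialized pattern, with illustrative names (gate, init, use) that are not Paddle code:

// Sketch: readers block on a channel that a one-time initializer
// closes, mirroring how the Service above blocks RPC handlers
// until SetDataset runs.
package main

import (
	"fmt"
	"sync"
	"time"
)

type gate struct {
	ready    chan struct{}
	mu       sync.Mutex
	initDone bool
}

func newGate() *gate {
	return &gate{ready: make(chan struct{})}
}

// init marks the gate ready; only the first call has an effect,
// just as only the first SetDataset call is honored.
func (g *gate) init() {
	g.mu.Lock()
	defer g.mu.Unlock()
	if g.initDone {
		return
	}
	close(g.ready) // wakes every waiter, current and future
	g.initDone = true
}

// use blocks until init has been called.
func (g *gate) use() {
	<-g.ready
	fmt.Println("proceeding")
}

func main() {
	g := newGate()
	go func() {
		time.Sleep(100 * time.Millisecond)
		g.init()
	}()
	g.use() // blocks until init closes the channel
}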
@@ -6,7 +6,7 @@ import (
 	"sort"
 	"time"
 
-	"github.com/PaddlePaddle/Paddle/go/pserver/internal/connection"
+	"github.com/PaddlePaddle/Paddle/go/connection"
 )
 
 // TODO(helin): add RPC call retry logic
@@ -47,7 +47,7 @@ func NewClient(l Lister, pserverNum int, sel Selector) *Client {
 // monitorPservers monitors pserver addresses, and updates connection
 // when the address changes.
 func (c *Client) monitorPservers(l Lister, pserverNum int) {
-	knownServers := make([]Server, pserverNum)
+	lastServers := make([]Server, pserverNum)
 	ticker := time.NewTicker(10 * time.Second)
 	monitor := func() {
 		curServers := make([]Server, pserverNum)
@@ -56,25 +56,37 @@ func (c *Client) monitorPservers(l Lister, pserverNum int) {
 			curServers[l.Index] = l
 		}
 
-		for i := range knownServers {
-			if knownServers[i].Addr != curServers[i].Addr {
-				err := c.pservers[i].Connect(curServers[i].Addr)
-				if err != nil {
-					log.Println(err)
-
-					// connect to addr failed, set
-					// to last known addr in order
-					// to retry next time.
-					curServers[i].Addr = knownServers[i].Addr
-				}
+		for i := range lastServers {
+			if lastServers[i].Addr == curServers[i].Addr {
+				continue
+			}
+
+			if curServers[i].Addr == "" {
+				err := c.pservers[i].Close()
+				if err != nil {
+					log.Println(err)
+				}
+
+				continue
+			}
+
+			err := c.pservers[i].Connect(curServers[i].Addr)
+			if err != nil {
+				log.Println(err)
+
+				// connect to addr failed, set
+				// to last known addr in order
+				// to retry next time.
+				curServers[i].Addr = lastServers[i].Addr
 			}
 		}
 
-		knownServers = curServers
+		lastServers = curServers
 	}
 
 	monitor()
-	for _ = range ticker.C {
+	for range ticker.C {
 		monitor()
 	}
 }
@@ -93,16 +105,14 @@ func (c *Client) BeginInitParams() bool {
 
 // InitParam initializes the parameter on parameter servers.
 func (c *Client) InitParam(paramWithConfigs ParameterWithConfig) error {
-	var dummy int
-	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, &dummy)
+	return c.pservers[c.partition(paramWithConfigs.Param.Name)].Call("Service.InitParam", paramWithConfigs, nil)
 }
 
 // FinishInitParams tells parameter servers client has sent all
 // parameters to parameter servers as initialization.
 func (c *Client) FinishInitParams() error {
 	for _, p := range c.pservers {
-		var dummy int
-		err := p.Call("Service.FinishInitParams", dummy, &dummy)
+		err := p.Call("Service.FinishInitParams", 0, nil)
 		if err != nil {
 			return err
 		}
@@ -116,8 +126,7 @@ func (c *Client) SendGrads(grads []Gradient) error {
 	errCh := make(chan error, len(grads))
 	for _, g := range grads {
 		go func(g Gradient) {
-			var dummy int
-			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, &dummy)
+			err := c.pservers[c.partition(g.Name)].Call("Service.SendGrad", g, nil)
 			errCh <- err
 		}(g)
 	}
@@ -196,8 +205,7 @@ func (c *Client) Save(path string) error {
 	errCh := make(chan error, len(c.pservers))
 	for _, p := range c.pservers {
-		var dummy int
-		err := p.Call("Service.Save", path, &dummy)
+		err := p.Call("Service.Save", path, nil)
 		errCh <- err
 	}
......
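An aside on the removed dummy variables: net/rpc service methods still take a pointer-shaped reply argument, but on the caller side a nil reply appears to be acceptable with the default gob codec (gob's Decode discards the value when handed nil), and in direct, non-RPC calls like the tests below, nil simply means the method must not dereference its reply. A hypothetical sketch under those assumptions:

// Sketch: an RPC-style method whose reply is unused. Callers that
// don't need the reply can pass nil, as the updated pserver client
// and tests do. EchoService and Ping are illustrative names.
package main

import "fmt"

type EchoService struct{}

// Ping follows the net/rpc method shape: value args, pointer reply.
// It guards against a nil reply so direct calls may pass nil.
func (s *EchoService) Ping(args int, reply *int) error {
	if reply != nil {
		*reply = args
	}
	return nil
}

func main() {
	s := &EchoService{}
	// Direct invocation with a nil reply, as in the updated tests.
	if err := s.Ping(0, nil); err != nil {
		fmt.Println(err)
	}
}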
@@ -15,8 +15,7 @@ func TestFull(t *testing.T) {
 	p.Name = "param_a"
 	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
 	p.ElementType = pserver.Int32
-	var dummy int
-	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
+	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
 	if err != nil {
 		t.FailNow()
 	}
@@ -25,12 +24,12 @@ func TestFull(t *testing.T) {
 	p1.Name = "param_b"
 	p1.Content = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}
 	p1.ElementType = pserver.Float32
-	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, &dummy)
+	err = s.InitParam(pserver.ParameterWithConfig{Param: p1, Config: nil}, nil)
 	if err != nil {
 		t.FailNow()
 	}
 
-	err = s.FinishInitParams(0, &dummy)
+	err = s.FinishInitParams(0, nil)
 	if err != nil {
 		t.FailNow()
 	}
@@ -46,11 +45,11 @@ func TestFull(t *testing.T) {
 	}
 
 	g1, g2 := pserver.Gradient(p1), pserver.Gradient(p)
-	err = s.SendGrad(g1, &dummy)
+	err = s.SendGrad(g1, nil)
 	if err != nil {
 		t.FailNow()
 	}
-	err = s.SendGrad(g2, &dummy)
+	err = s.SendGrad(g2, nil)
 	if err != nil {
 		t.FailNow()
@@ -74,13 +73,12 @@ func TestFull(t *testing.T) {
 func TestMultipleInit(t *testing.T) {
 	s := pserver.NewService()
-	var dummy int
-	err := s.FinishInitParams(0, &dummy)
+	err := s.FinishInitParams(0, nil)
 	if err != nil {
 		t.FailNow()
 	}
 
-	err = s.FinishInitParams(0, &dummy)
+	err = s.FinishInitParams(0, nil)
 	if err.Error() != pserver.AlreadyInitialized {
 		t.FailNow()
 	}
@@ -88,8 +86,7 @@ func TestMultipleInit(t *testing.T) {
 func TestUninitialized(t *testing.T) {
 	s := pserver.NewService()
-	var dummy int
-	err := s.SendGrad(pserver.Gradient{}, &dummy)
+	err := s.SendGrad(pserver.Gradient{}, nil)
 	if err.Error() != pserver.Uninitialized {
 		t.FailNow()
 	}
@@ -98,13 +95,14 @@ func TestUninitialized(t *testing.T) {
 func TestBlockUntilInitialized(t *testing.T) {
 	s := pserver.NewService()
 	ch := make(chan struct{}, 2)
+	errCh := make(chan error, 2)
 	var wg sync.WaitGroup
 	wg.Add(1)
 	go func() {
 		var param pserver.Parameter
 		err := s.GetParam("param_a", &param)
 		if err != nil {
-			t.FailNow()
+			errCh <- err
 		}
 		wg.Done()
 		ch <- struct{}{}
@@ -112,10 +110,9 @@ func TestBlockUntilInitialized(t *testing.T) {
 	wg.Add(1)
 	go func() {
-		var dummy int
-		err := s.Save("", &dummy)
+		err := s.Save("", nil)
 		if err != nil {
-			t.FailNow()
+			errCh <- err
 		}
 		wg.Done()
 		ch <- struct{}{}
@@ -127,6 +124,8 @@ func TestBlockUntilInitialized(t *testing.T) {
 	case <-ch:
 		// some function returned before initialization is completed.
 		t.FailNow()
+	case <-errCh:
+		t.FailNow()
 	default:
 	}
@@ -134,13 +133,12 @@ func TestBlockUntilInitialized(t *testing.T) {
 	p.Name = "param_a"
 	p.Content = []byte{1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0}
 	p.ElementType = pserver.Int32
-	var dummy int
-	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, &dummy)
+	err := s.InitParam(pserver.ParameterWithConfig{Param: p, Config: nil}, nil)
 	if err != nil {
 		t.FailNow()
 	}
 
-	err = s.FinishInitParams(0, &dummy)
+	err = s.FinishInitParams(0, nil)
 	if err != nil {
 		t.FailNow()
 	}
......