提交 e2fae168 编写于 作者: H Helin Wang

SendGrad will return error if pserver is not initialized.

上级 ea18f2ee
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
type ElementType int type ElementType int
var ErrAlreadyInitialized = errors.New("pserver already initialized") var ErrAlreadyInitialized = errors.New("pserver already initialized")
var ErrUninitialized = errors.New("pserver not fully initialized")
// Supported element types // Supported element types
const ( const (
...@@ -111,7 +112,11 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { ...@@ -111,7 +112,11 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error {
// SendGrads sends gradients to parameter servers for parameter // SendGrads sends gradients to parameter servers for parameter
// optimization. // optimization.
func (s *Service) SendGrads(grads []Gradient, dummy *int) error { func (s *Service) SendGrads(grads []Gradient, dummy *int) error {
<-s.initialized select {
case <-s.initialized:
default:
return ErrUninitialized
}
count := len(grads) count := len(grads)
if count == 0 { if count == 0 {
......
...@@ -108,9 +108,18 @@ func TestMultipleInit(t *testing.T) { ...@@ -108,9 +108,18 @@ func TestMultipleInit(t *testing.T) {
} }
} }
func TestUninitialized(t *testing.T) {
s := pserver.NewService()
var dummy int
err := s.SendGrads(nil, &dummy)
if err != pserver.ErrUninitialized {
t.FailNow()
}
}
func TestBlockUntilInitialized(t *testing.T) { func TestBlockUntilInitialized(t *testing.T) {
s := pserver.NewService() s := pserver.NewService()
ch := make(chan struct{}, 3) ch := make(chan struct{}, 2)
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(1) wg.Add(1)
go func() { go func() {
...@@ -134,17 +143,6 @@ func TestBlockUntilInitialized(t *testing.T) { ...@@ -134,17 +143,6 @@ func TestBlockUntilInitialized(t *testing.T) {
ch <- struct{}{} ch <- struct{}{}
}() }()
wg.Add(1)
go func() {
var dummy int
err := s.SendGrads(nil, &dummy)
if err != nil {
t.FailNow()
}
wg.Done()
ch <- struct{}{}
}()
var dummy int var dummy int
err := s.BeginInitParams(nil, &dummy) err := s.BeginInitParams(nil, &dummy)
if err != nil { if err != nil {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册