diff --git a/paddle/go/pserver/service.go b/paddle/go/pserver/service.go index a009b4563309d8f561deb252ac3e1a8a68511791..f43e59403a71cb5bed2187c2f2f80465642a5c65 100644 --- a/paddle/go/pserver/service.go +++ b/paddle/go/pserver/service.go @@ -10,6 +10,7 @@ import ( type ElementType int var ErrAlreadyInitialized = errors.New("pserver already initialized") +var ErrUninitialized = errors.New("pserver not fully initialized") // Supported element types const ( @@ -111,7 +112,11 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { // SendGrads sends gradients to parameter servers for parameter // optimization. func (s *Service) SendGrads(grads []Gradient, dummy *int) error { - <-s.initialized + select { + case <-s.initialized: + default: + return ErrUninitialized + } count := len(grads) if count == 0 { diff --git a/paddle/go/pserver/service_test.go b/paddle/go/pserver/service_test.go index 6aa7f47c7475ca27f5d1155609349c3a73f19827..10185bd0f2096bd85ff7d0fb688a4aa820e5308c 100644 --- a/paddle/go/pserver/service_test.go +++ b/paddle/go/pserver/service_test.go @@ -108,9 +108,18 @@ func TestMultipleInit(t *testing.T) { } } +func TestUninitialized(t *testing.T) { + s := pserver.NewService() + var dummy int + err := s.SendGrads(nil, &dummy) + if err != pserver.ErrUninitialized { + t.FailNow() + } +} + func TestBlockUntilInitialized(t *testing.T) { s := pserver.NewService() - ch := make(chan struct{}, 3) + ch := make(chan struct{}, 2) var wg sync.WaitGroup wg.Add(1) go func() { @@ -134,17 +143,6 @@ func TestBlockUntilInitialized(t *testing.T) { ch <- struct{}{} }() - wg.Add(1) - go func() { - var dummy int - err := s.SendGrads(nil, &dummy) - if err != nil { - t.FailNow() - } - wg.Done() - ch <- struct{}{} - }() - var dummy int err := s.BeginInitParams(nil, &dummy) if err != nil {