From e2fae1685de73772fb2b46ed67b4ddc0b897c83c Mon Sep 17 00:00:00 2001 From: Helin Wang Date: Fri, 19 May 2017 15:30:53 -0400 Subject: [PATCH] SendGrad will return error if pserver is not initialized. --- paddle/go/pserver/service.go | 7 ++++++- paddle/go/pserver/service_test.go | 22 ++++++++++------------ 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/paddle/go/pserver/service.go b/paddle/go/pserver/service.go index a009b456330..f43e59403a7 100644 --- a/paddle/go/pserver/service.go +++ b/paddle/go/pserver/service.go @@ -10,6 +10,7 @@ import ( type ElementType int var ErrAlreadyInitialized = errors.New("pserver already initialized") +var ErrUninitialized = errors.New("pserver not fully initialized") // Supported element types const ( @@ -111,7 +112,11 @@ func (s *Service) FinishInitParams(dummy0 int, dummy1 *int) error { // SendGrads sends gradients to parameter servers for parameter // optimization. func (s *Service) SendGrads(grads []Gradient, dummy *int) error { - <-s.initialized + select { + case <-s.initialized: + default: + return ErrUninitialized + } count := len(grads) if count == 0 { diff --git a/paddle/go/pserver/service_test.go b/paddle/go/pserver/service_test.go index 6aa7f47c747..10185bd0f20 100644 --- a/paddle/go/pserver/service_test.go +++ b/paddle/go/pserver/service_test.go @@ -108,9 +108,18 @@ func TestMultipleInit(t *testing.T) { } } +func TestUninitialized(t *testing.T) { + s := pserver.NewService() + var dummy int + err := s.SendGrads(nil, &dummy) + if err != pserver.ErrUninitialized { + t.FailNow() + } +} + func TestBlockUntilInitialized(t *testing.T) { s := pserver.NewService() - ch := make(chan struct{}, 3) + ch := make(chan struct{}, 2) var wg sync.WaitGroup wg.Add(1) go func() { @@ -134,17 +143,6 @@ func TestBlockUntilInitialized(t *testing.T) { ch <- struct{}{} }() - wg.Add(1) - go func() { - var dummy int - err := s.SendGrads(nil, &dummy) - if err != nil { - t.FailNow() - } - wg.Done() - ch <- struct{}{} - }() - var dummy int err := s.BeginInitParams(nil, &dummy) if err != nil { -- GitLab