diff --git a/paddle/go/pserver/service.go b/paddle/go/pserver/service.go index 22f6cdf40d49a43b5843808afdbc724509877644..47a862c5ad2c2ec3f777a7237d870bd8c8e335f2 100644 --- a/paddle/go/pserver/service.go +++ b/paddle/go/pserver/service.go @@ -109,6 +109,11 @@ func (s *Service) SendGrads(grads []Gradient, dummy *int) error { return ErrUnintialized } + count := len(grads) + if count == 0 { + return nil + } + s.mu.Lock() defer s.mu.Unlock() @@ -118,16 +123,25 @@ func (s *Service) SendGrads(grads []Gradient, dummy *int) error { } } - var wg sync.WaitGroup + errCh := make(chan error, count) for _, g := range grads { - wg.Add(1) go func(p Parameter, g Gradient) { - s.opt.UpdateParameter(p, g) - wg.Done() + err := s.opt.UpdateParameter(p, g) + errCh <- err }(s.paramMap[g.Name], g) } - wg.Wait() + recv := 0 + for err := range errCh { + if err != nil { + return err + } + + recv++ + if recv == count { + break + } + } return nil }